From c9fec32d6ef2d53a0d02f5b0ca39d87e184e4447 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Thu, 15 May 2025 22:50:34 +0000 Subject: [PATCH 001/102] Port ROCm changes from multi-backend-refactor branch --- CMakeLists.txt | 80 +- bitsandbytes/cextension.py | 82 +- bitsandbytes/diagnostics/cuda.py | 70 +- bitsandbytes/diagnostics/main.py | 26 +- csrc/common_hip.cuh | 7 + csrc/kernels.hip | 3253 ++++++++++++++++++++++++++++++ csrc/kernels_hip.cuh | 132 ++ csrc/ops.hip | 836 ++++++++ csrc/ops_hip.cuh | 195 ++ csrc/pythonInterface.cpp | 22 +- 10 files changed, 4654 insertions(+), 49 deletions(-) create mode 100644 csrc/common_hip.cuh create mode 100644 csrc/kernels.hip create mode 100644 csrc/kernels_hip.cuh create mode 100644 csrc/ops.hip create mode 100644 csrc/ops_hip.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b462c45d..8a7583279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,13 +25,14 @@ endif() # Define included source files set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP ON) set(BUILD_MPS OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) endif() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS ON) else() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS OFF) endif() @@ -160,6 +171,36 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${HIP_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_rocm") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") + + string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") + if(HIP_VERSION VERSION_LESS "6.1") + string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") + endif() + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) elseif(BUILD_MPS) if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -208,6 +249,41 @@ if(BUILD_CUDA) CUDA_SEPARABLE_COMPILATION ON ) endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 3fb8db26f..c8b02fb22 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -22,11 +22,17 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" - library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" + blas_suffix = "_nohipblaslt" if torch.version.hip and cuda_specs.cuda_version_tuple < (6, 1) else "" + library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{blas_suffix}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + if torch.version.hip: + raise RuntimeError( + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + ) logger.warning( f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" @@ -72,10 +78,11 @@ def __init__(self, lib: ct.CDLL): def get_available_cuda_binary_versions() -> list[str]: """Get formatted CUDA versions from existing library files using cuda_specs logic""" - lib_pattern = f"libbitsandbytes_cuda*{DYNAMIC_LIBRARY_SUFFIX}" + lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - match = re.search(r"cuda(\d{3})", lib.name) + pattern = r"{}(\d+)".format(BNB_BACKEND.lower()) + match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) major = ver_code // 10 @@ -86,8 +93,8 @@ def get_available_cuda_binary_versions() -> list[str]: def parse_cuda_version(version_str: str) -> str: """Convert raw version string (e.g. '118' from env var) to formatted version (e.g. '11.8')""" - if version_str.isdigit() and len(version_str) == 3: - return f"{version_str[:2]}.{version_str[2]}" + if version_str.isdigit(): + return f"{version_str[:-1]}.{version_str[-1]}" return version_str # fallback as safety net @@ -148,7 +155,7 @@ def _format_lib_error_message( """Format detailed error message for library loading failures""" analysis = "" no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error - no_cuda_lib_found = "CUDA binary not found" in original_error + no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error if no_cpu_lib_found: analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n" @@ -157,9 +164,9 @@ def _format_lib_error_message( version_list_str = "\n - " + "\n - ".join(available_versions) if available_versions else "NONE" analysis = ( ( - f"\n🚨 CUDA VERSION MISMATCH 🚨\n" - f"Requested CUDA version: {requested_version}\n" - f"Detected PyTorch CUDA version: {user_cuda_version}\n" + f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n" + f"Requested {BNB_BACKEND} version: {requested_version}\n" + f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n" f"Available pre-compiled versions: {version_list_str}\n\n" "This means:\n" "The version you're trying to use is NOT distributed with this package\n\n" @@ -174,42 +181,49 @@ def _format_lib_error_message( troubleshooting = ( ( - "This typically happens when:\n" - "1. bitsandbytes doesn't ship with a pre-compiled binary for your CUDA version\n" - "2. The library wasn't compiled properly during installation from source\n\n" + f"This typically happens when:\n" + f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n" + f"2. The library wasn't compiled properly during installation from source\n\n" ) if no_cuda_lib_found - else "This typically happens when you checked the code out from source and your torch installation doesn't detect CUDA on your machine.\n\n" + else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n" ) note = ( ( - "To make bitsandbytes work, the compiled library version MUST exactly match the linked CUDA version.\n" - "If your CUDA version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" + f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n" + f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" ) if no_cuda_lib_found else "" ) compile_instructions = ( + ( + "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" + ) if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" + ) if not HIP_ENVIRONMENT + else + ( + "You can COMPILE FROM SOURCE as mentioned here:\n" + " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" ) - if no_cuda_lib_found - else "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" ) diagnostics = ( - "šŸ” Run this command for detailed diagnostics:\n" - "python -m bitsandbytes\n\n" - "If you've tried everything and still have issues:\n" - "1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" - "2. Describe what you've tried in detail\n" - "3. Open an issue with this information:\n" - " https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" + f"šŸ” Run this command for detailed diagnostics:\n" + f"python -m bitsandbytes\n\n" + f"If you've tried everything and still have issues:\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" + f"2. Describe what you've tried in detail\n" + f"3. Open an issue with this information:\n" + f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" ) return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}" @@ -224,18 +238,19 @@ def _format_dependency_error(self) -> str: ) return ( - f"\n🚨 CUDA SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" - f"CUDA {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" + f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" + f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" f"To fix this, make sure that:\n" - f"1. You have installed CUDA {cuda_major_version}.x toolkit on your system\n" - f"2. The CUDA runtime libraries are in your LD_LIBRARY_PATH\n\n" + f"1. You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n" + f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n" f"You can add them with (and persist the change by adding the line to your .bashrc):\n" - f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cuda-{cuda_major_version}.x/lib64\n\n" + f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\ + {'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n" f"Original error: {self.error_msg}\n\n" f"šŸ” Run this command for detailed diagnostics:\n" f"python -m bitsandbytes\n\n" f"If you've tried everything and still have issues:\n" - f"1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" f"2. Describe what you've tried in detail\n" f"3. Open an issue with this information:\n" f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" @@ -264,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary: cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) if not cuda_binary_path.exists(): - raise RuntimeError(f"Configured CUDA binary not found at {cuda_binary_path}") + raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") binary_path = cuda_binary_path @@ -284,6 +299,11 @@ def get_native_library() -> BNBNativeLibrary: try: + if torch.version.hip: + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + else: + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" + lib = get_native_library() except Exception as e: error_msg = str(e) diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index affcb0ae6..b9de27fd7 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL from bitsandbytes.cuda_specs import CUDASpecs from bitsandbytes.diagnostics.utils import print_dedented @@ -33,6 +33,8 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( + "libamdhip64.so*", +) if HIP_ENVIRONMENT else ( "cudart64*.dll", # Windows "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. "nvcuda*.dll", # Windows @@ -57,7 +59,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path pass for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: for pth in dir.glob(lib_pattern): - if pth.is_file(): + if pth.is_file() and not pth.is_symlink(): yield pth except (OSError, PermissionError): pass @@ -104,7 +106,7 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print( f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", @@ -149,7 +151,37 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: # (2) Multiple CUDA versions installed -def print_cuda_runtime_diagnostics() -> None: +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCm version + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(cuda_specs) + else: + _print_cuda_diagnostics(cuda_specs) + + +def _print_cuda_runtime_diagnostics() -> None: cudart_paths = list(find_cudart_libraries()) if not cudart_paths: print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") @@ -174,3 +206,33 @@ def print_cuda_runtime_diagnostics() -> None: ) for pth in cudart_paths: print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. + + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in cudart_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index b6236d668..8e2bc2a7b 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -3,11 +3,12 @@ import torch +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( - print_cuda_diagnostics, - print_cuda_runtime_diagnostics, + print_diagnostics, + print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -34,19 +35,24 @@ def main(): print_header("OTHER") cuda_specs = get_cuda_specs() - print("CUDA specs:", cuda_specs) + if HIP_ENVIRONMENT: + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + print(f"{BNB_BACKEND} specs:{rocm_specs}") + else: + print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): - print("Torch says CUDA is not available. Possible reasons:") - print("1. CUDA driver not installed") - print("2. CUDA not installed") - print("3. You have multiple conflicting CUDA libraries") + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + print(f"- {BNB_BACKEND} not installed") + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: - print_cuda_diagnostics(cuda_specs) - print_cuda_runtime_diagnostics() + print_diagnostics(cuda_specs) + print_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") print_header("") - print("Checking that the library is importable and CUDA is callable...") + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") try: sanity_check() print("SUCCESS!") diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh new file mode 100644 index 000000000..e7fc4eb81 --- /dev/null +++ b/csrc/common_hip.cuh @@ -0,0 +1,7 @@ +#pragma once + +#define BNB_WARP_SIZE warpSize + +// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs +#define BNB_MAX_THREADS_PER_SM 2048 +#define BNB_BF16_AVAILABLE true diff --git a/csrc/kernels.hip b/csrc/kernels.hip new file mode 100644 index 000000000..368788f39 --- /dev/null +++ b/csrc/kernels.hip @@ -0,0 +1,3253 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "kernels_hip.cuh" +#include "common_hip.cuh" +#include +#include +#include + +//#include + + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// Luckily we have atomicmax and atomicmin in ROCm + + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + + +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef hipcub::BlockRadixSort BlockRadixSort; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + __syncthreads(); + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); + + if(threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } + __syncthreads(); + + local_abs_max = smem_absmax_value[0]; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if (DATA_TYPE > 0) + { + valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); + } + else + { + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; + } + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch (DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + // AdEMAMix has an additional state buffer, which we packed + // into state1. We need thread-local storage here for these. + // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. + float s3_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // Load additional state1 data for AdEMAMix + // TODO: Make constexpr after updating min compiler + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADEMAMIX: + // m1 update: m1 = beta1 * m1 + (1-beta1) * g + s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); + + // m2 update: m2 = m2 * beta3 + (1-beta3) * g + s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); + + // nu update: nu = beta2 * nu + (1-beta2) * g^2 + s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); + + p_vals[j] = (float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + ); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + + break; + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); + } + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER){ + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise( + T* p, + T* __restrict__ const g, + unsigned char* state1, + unsigned char* state2, + const float beta1, + const float beta2, + const float beta3, + const float alpha, + const float eps, + const int step, + const float lr, + float* __restrict__ const quantiles1, + float* __restrict__ const quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + const bool skip_zeros, + const int n +) { + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + float s3_vals[N_PER_TH]; + + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float new_local_abs_max3 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + unsigned char c3s[N_PER_TH]; + + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; + typedef hipcub::BlockReduce BlockReduce3; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ typename BlockReduce2::TempStorage reduce3; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + __shared__ float smem_exchange3[1]; // [[maybe_unused]] + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + // AdEMAMix has an additional state packed into state1. + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); + } + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + new_local_abs_max3 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + if (OPTIMIZER == ADEMAMIX) { + // The absmax for the third state is appended to absmax1 + s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; + s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); + } + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + + if (OPTIMIZER == ADEMAMIX) { + s3_vals[j] = 0.0f; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); + } + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, hipcub::Max()); + } + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + smem_exchange3[0] = new_local_abs_max3; + } + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; + } + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = smem_exchange3[0]; + } + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + if (OPTIMIZER == ADEMAMIX) { + p_vals[j] = T((float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + )); + } else { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + } + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + if (OPTIMIZER == ADEMAMIX) { + c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3)); + + if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { + c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; + } + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); + } + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +// Inputs: +// A [rows, cols] +// Outputs: +// rowStats [rows] +// out [rows, cols] +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { + + // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. + // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE + using TReduction = T; +#else + using TReduction = float; +#endif + + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ TReduction smem_row_absmax; + + const int row_id = blockIdx.x; + const T* row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + TReduction row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const TReduction absval = fabsf(__ldcs(&(row_data[i]))); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); + + // Quantize row-wise. + const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. + out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(val * scale); + } + } +} + +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(row_data[i]); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = row_absmax; + } +} + +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + half *out, + half *__restrict__ const bias, + const int numRows, + const int numCols, + const int n +) { + const int n_out = numRows * numCols; + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + + float local_rowStats[ITEMS_PER_THREAD]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + + typedef hipcub::BlockLoad LoadInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + + int row_idx, col_idx; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + } + + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out + ? THREADS * ITEMS_PER_THREAD + : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +#define WARP_SIZE warpSize +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_idx = threadIdx.x % WARP_SIZE; + const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + //// element-wise kernel + //// 1. Load batch x k into registers + //// 2. Load k x k into registers + //// 3. dequantize and store in second pair of k x k + //// 4. matmul + //// 5. sum with cub + //// 6. store outputs + //// TC kernel + //// use k warps per thread block + //// 1. threadblock use read-only cache to read in register tile for A into shared memory + //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments + //// 3. each warp reads a segment of values 16x32 from B + //// 4. do dequantization from register of B into second pair of registers + //// 5. store (4) into fragment + //// 6. matmul aggregate into fragment C + //// 7. aggregate files of C into shared memory block C + //// 8. sum (7) + //// 9. write outputs to matmul output matrix +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +// No of 4bit values processed by each thread +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + + const int warp_idx = threadIdx.x / warpSize; + const int warp_lane = threadIdx.x % warpSize; + const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; + const int offset_B = ldb*row_B; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + //for(int i = threadIdx.x; i < 16; i++) + //quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [M, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + { + const int inner_idx_halved = inner_idx/2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); + + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) +MAKE_Optimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, hip_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(ADAGRAD, half) +MAKE_PreconditionStatic8bit1State(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(ADAGRAD, half) +MAKE_optimizerStatic8bit1State(ADAGRAD, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, const float beta3, const float alpha, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, hip_bfloat16, 256, 1) + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1) + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh new file mode 100644 index 000000000..2895012f8 --- /dev/null +++ b/csrc/kernels_hip.cuh @@ -0,0 +1,132 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#ifndef kernels +#define kernels + + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/ops.hip b/csrc/ops.hip new file mode 100644 index 000000000..4d077d19a --- /dev/null +++ b/csrc/ops.hip @@ -0,0 +1,836 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#ifndef NO_HIPBLASLT +#include +#endif +#include +#include +#include +#include + +#define ERR_NOT_IMPLEMENTED 100 + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); + hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, stream, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 64) + // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n); + else + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update after the parameter update + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 + +template void optimizerStatic8bitBlockwise( + T* p, + T* g, + unsigned char* state1, + unsigned char* state2, + float beta1, + float beta2, + float beta3, + float alpha, + float eps, + int step, + float lr, + float* quantiles1, + float* quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + bool skip_zeros, + int n +) { + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + C, HIPBLAS_R_32I, ldc, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +#ifdef NO_HIPBLASLT +#else +template hipblasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return HIPBLASLT_ORDER_ROW; + break; + case COL: + return HIPBLASLT_ORDER_COL; + break; + case COL32: + //return HIPBLASLT_ORDER_COL32; + return HIPBLASLT_ORDER_COL; + break; + case COL_TURING: + //return HIPBLASLT_ORDER_COL4_4R2_8C; + return HIPBLASLT_ORDER_COL; + break; + case COL_AMPERE: + //return HIPBLASLT_ORDER_COL32_2R_4R4; + return HIPBLASLT_ORDER_COL; + break; + default: + break; + } + + return HIPBLASLT_ORDER_ROW; +} + +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +#endif + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + default: + return dim1; + break; + /*case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + */ + } +} + +static std::string hipError_to_string(const hipError_t ret) +{ + switch(ret) + { + case hipSuccess: + return "hipSuccess"; + case hipErrorInvalidContext: + return "hipErrorInvalidContext"; + case hipErrorInvalidKernelFile: + return "hipErrorInvalidKernelFile"; + case hipErrorMemoryAllocation: + return "hipErrorMemoryAllocation"; + case hipErrorInitializationError: + return "hipErrorInitializationError"; + case hipErrorLaunchFailure: + return "hipErrorLaunchFailure"; + case hipErrorLaunchOutOfResources: + return "hipErrorLaunchOutOfResources"; + case hipErrorInvalidDevice: + return "hipErrorInvalidDevice"; + case hipErrorInvalidValue: + return "hipErrorInvalidValue"; + case hipErrorInvalidDevicePointer: + return "hipErrorInvalidDevicePointer"; + case hipErrorInvalidMemcpyDirection: + return "hipErrorInvalidMemcpyDirection"; + case hipErrorUnknown: + return "hipErrorUnknown"; + case hipErrorInvalidResourceHandle: + return "hipErrorInvalidResourceHandle"; + case hipErrorNotReady: + return "hipErrorNotReady"; + case hipErrorNoDevice: + return "hipErrorNoDevice"; + case hipErrorPeerAccessAlreadyEnabled: + return "hipErrorPeerAccessAlreadyEnabled"; + case hipErrorPeerAccessNotEnabled: + return "hipErrorPeerAccessNotEnabled"; + case hipErrorRuntimeMemory: + return "hipErrorRuntimeMemory"; + case hipErrorRuntimeOther: + return "hipErrorRuntimeOther"; + case hipErrorHostMemoryAlreadyRegistered: + return "hipErrorHostMemoryAlreadyRegistered"; + case hipErrorHostMemoryNotRegistered: + return "hipErrorHostMemoryNotRegistered"; + case hipErrorMapBufferObjectFailed: + return "hipErrorMapBufferObjectFailed"; + case hipErrorTbd: + return "hipErrorTbd"; + default: + throw std::runtime_error("unknown hipError"); + } +} + +template int igemmlt( + hipblasLtHandle_t ltHandle, + int m, int n, int k, + const int8_t *A, + const int8_t *B, + void *C, + float *row_scale, + int lda, int ldb, int ldc, + hipStream_t stream +) { +#ifdef NO_HIPBLASLT + return ERR_NOT_IMPLEMENTED; +#else + + // Calculate C = A^T @ B, in col-major layout. + // + // Use the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. + // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + + int has_error = 0; + const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel + + hipblasLtMatmulDesc_t matmulDesc; + hipblasLtMatrixLayout_t aDesc, bDesc, cDesc; + hipblasOperation_t opT = HIPBLAS_OP_T; + + hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I; + hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F; + + hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + + // Default layout order is col major + + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + + /* Algo and workspace TODO: need to rework to not be duplicated */ + // Set User Preference attributes + hipblasLtMatmulPreference_t pref; + checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); + checkHipblasStatus( + hipblasLtMatmulPreferenceSetAttribute(pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, + matmulDesc, + aDesc, + bDesc, + cDesc, + cDesc, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if (returnedAlgoCount == 0) + { + has_error = 1; + fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); + } else { + int alpha = 1, beta = 0; + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int32_t*)C, cDesc, + (int32_t*)C, cDesc, + &heuristicResult[0].algo, NULL, 0, stream + )); + } + } else { + // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. + + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } else { + hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute( + matmulDesc, + HIPBLASLT_MATMUL_DESC_POINTER_MODE, + &pointerMode, + sizeof(alphaVec) + )); + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + row_scale, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } + } + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); + + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif // NO_HIPBLASLT +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream) +{ + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + hipLaunchKernelGGL(( kdequant_mm_int32_fp16), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { + if (threshold == 0.0) + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + else + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_HIPBLASLT +#else + + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( hipFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + if(bits == 32) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) +{ + + //warpsize - 32 + int num_blocks = (m+3)/4; + //warpsize - 64 + if (warpSize == 64) { + num_blocks = (m+1)/2; + } + + hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(MOMENTUM, hip_bfloat16) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(RMSPROP, hip_bfloat16) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) +MAKE_optimizer32bit(ADAGRAD, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, half) +MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) +MAKE_optimizerStatic8bit(ADAGRAD, half) +MAKE_optimizerStatic8bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh new file mode 100644 index 000000000..bcfc73e99 --- /dev/null +++ b/csrc/ops_hip.cuh @@ -0,0 +1,195 @@ +// !!! This is a file automatically generated by hipify!!! +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define CUDA_CHECK_RETURN(value) { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define CHECK_HIPSPARSE(value) { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + + +inline void checkHipStatus(hipError_t status) { + if (status != hipSuccess) { + printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); + throw std::logic_error("hip API failed"); + } +} + +inline int checkHipblasStatus(hipblasStatus_t status) { + if (status != HIPBLAS_STATUS_SUCCESS) { + printf("hipBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + rocblas_handle m_handle; + + Context() + { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + hipblasLtHandle_t m_handle; + + ContextLt() + { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } +}; + +class ContextHipsparse +{ + public: + hipsparseHandle_t m_handle; + + ContextHipsparse() + { + hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, hipStream_t stream); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, hipStream_t stream); +void getRowStats(half * A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 56bec82e8..a48514542 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -6,11 +6,29 @@ #if BUILD_CUDA #include #endif +#if BUILD_HIP +#include +#endif #if BUILD_MPS // #include #endif #include +// Compatibility between HIP/CUDA APIs +#if BUILD_HIP +#define cudaStream_t hipStream_t +#define __nv_bfloat16 hip_bfloat16 +#define cublasLtHandle_t hipblasLtHandle_t +#define ContextCusparse ContextHipsparse +#define cusparseHandle_t hipsparseHandle_t +#define cudaMallocManaged hipMallocManaged +#define cudaMemAttachHost hipMemAttachHost +#define cudaPeekAtLastError hipPeekAtLastError +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess +#define cudaMemPrefetchAsync hipMemPrefetchAsync +#endif + // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to // maintain all that boilerplate @@ -18,7 +36,7 @@ // UNMANGLED CALLS //=================================================================================== -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } @@ -168,7 +186,7 @@ void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_r extern "C" { -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } From d729c188496ce5947f159693fbbb3e2dd281d87e Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 20 May 2025 21:14:15 +0530 Subject: [PATCH 002/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 363 ++++++++++++++++++------------ 1 file changed, 223 insertions(+), 140 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..fd63c888d 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,14 +1,15 @@ from collections.abc import Sequence import ctypes as ct from math import prod -from typing import Optional +from typing import Optional import torch from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib +from ...cextension import lib, HIP_ENVIRONMENT + @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -84,7 +85,6 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor return out - @register_kernel("bitsandbytes::int8_mm_dequant", "cuda") def _( A: torch.Tensor, @@ -164,7 +164,7 @@ def _(A: torch.Tensor, threshold=0.0): out_row[:, outlier_cols] = 0 return out_row, row_stats, outlier_cols - + @register_kernel("bitsandbytes::int8_double_quant", "cuda") def _( @@ -210,35 +210,67 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - + + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + elif device_type == 'cpu': + cpu_kernel_func = getattr(lib, 'cquantize_blockwise_cpu_fp32', None) + if cpu_kernel_func: + A_cpu = A.to(torch.float32) if A.dtype != torch.float32 else A + code_cpu = code.to('cpu') + absmax_cpu = torch.empty(absmax.shape, device='cpu', dtype=torch.float32) + out_cpu = torch.empty(out.shape, device='cpu', dtype=torch.uint8) + + cpu_kernel_func( + get_ptr(code_cpu), + get_ptr(A_cpu), + get_ptr(absmax_cpu), + get_ptr(out_cpu), + ct.c_longlong(blocksize), + ct.c_longlong(A_cpu.numel()) + ) + + out.copy_(out_cpu) + absmax.copy_(absmax_cpu) + else: + raise NotImplementedError("CPU blockwise quantization requires C extension support") + else: + raise NotImplementedError(f"Blockwise quantization not implemented for {device_type}") + + return out, absmax + @register_kernel("bitsandbytes::dequantize_blockwise", "cuda") def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: @@ -252,7 +284,7 @@ def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, - blocksize: int, + blocksize: int, dtype: torch.dtype, out: torch.Tensor, ) -> None: @@ -264,76 +296,116 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + elif device_type == 'cpu': + cpu_kernel_func = getattr(lib, 'cdequantize_blockwise_cpu_fp32', None) + if cpu_kernel_func: + code_cpu = code.to('cpu') + A_cpu = A.to('cpu') + absmax_cpu = absmax.to('cpu') + out_cpu = torch.empty(out.shape, dtype=torch.float32, device='cpu') + + cpu_kernel_func( + get_ptr(code_cpu), + get_ptr(A_cpu), + get_ptr(absmax_cpu), + get_ptr(out_cpu), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()) + ) + + out.copy_(out_cpu.to(dtype)) + else: + raise NotImplementedError("CPU blockwise dequantization requires C extension support") + else: + raise NotImplementedError(f"Blockwise dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::quantize_4bit", "cuda") def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' or HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + else: + raise NotImplementedError(f"4-bit quantization not implemented for {device_type}") + + return out, absmax @register_kernel("bitsandbytes::dequantize_4bit", "cuda") def _( @@ -347,6 +419,7 @@ def _( out = torch.empty(shape, dtype=dtype, device=A.device) _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) return out + @register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") @@ -359,52 +432,62 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + else: + raise NotImplementedError(f"4-bit dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::gemv_4bit", "cuda") @@ -457,7 +540,7 @@ def _gemv_4bit_impl( B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") m = ct.c_int32(shapeB[0]) From 6459c2bd6e4eb68fbe36d3deb200ac3492f96c1a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 20 May 2025 21:15:00 +0530 Subject: [PATCH 003/102] Update functional.py --- bitsandbytes/functional.py | 391 ++++++++++++++++++++++--------------- 1 file changed, 238 insertions(+), 153 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b0092ffd1..7730f7182 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import lib, HIP_ENVIRONMENT name2qmap = {} @@ -719,152 +719,222 @@ def __eq__(self, other): ) -def quantize_blockwise( - A: torch.Tensor, - code: Optional[torch.Tensor] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=4096, - nested=False, -) -> tuple[torch.Tensor, QuantState]: - """Quantize a tensor in blocks of values. - - The input tensor is quantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is calculated for scaling - the non-linear quantization. - - Args: - A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - - Raises: - ValueError: Raised when the input data type is not supported. - - Returns: - `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - - `torch.Tensor`: The quantized tensor. - - [`QuantState`]: The state object used to undo the quantization. - """ - - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( - A, - code.to(A.device), - blocksize, - ) - - if nested: - offset = _absmax.mean() - _absmax -= offset - qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) - else: - quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) - - # TODO(matthewdouglas): Deprecate out kwarg - out = out.copy_(_out) if out is not None else _out - - # TODO(matthewdouglas): Deprecate absmax kwarg - if absmax is not None: - quant_state.absmax = absmax.copy_(quant_state.absmax) - - return out, quant_state - - -def dequantize_blockwise( - A: torch.Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 4096, - nested=False, -) -> torch.Tensor: - """Dequantize a tensor in blocks of values. - - The input tensor is dequantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is used for scaling - the non-linear dequantization. - - Args: - A (`torch.Tensor`): The quantized input tensor. - quant_state ([`QuantState`], *optional*): - The quantization state as returned by [`quantize_blockwise`]. - Required if `absmax` is not provided. - absmax (`torch.Tensor`, *optional*): - A tensor containing the scaling values. - Required if `quant_state` is not provided and ignored otherwise. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - Ignored when `quant_state` is provided. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - Ignored when `quant_state` is provided. - - Raises: - ValueError: Raised when the input data type is not supported. - - Returns: - `torch.Tensor`: - The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. - """ - - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - absmax = quant_state.absmax - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is not None: - torch.ops.bitsandbytes.dequantize_blockwise.out( - A, - absmax, - code.to(A.device), - blocksize, - quant_state.dtype, - out=out, - ) - return out - - return torch.ops.bitsandbytes.dequantize_blockwise.default( - A, - absmax, - quant_state.code.to(A.device), - quant_state.blocksize, - quant_state.dtype, - ) +def quantize_blockwise( + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> tuple[torch.Tensor, QuantState]: + """Quantize a tensor in blocks of values. + + The input tensor is quantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. + """ + + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if absmax is None: + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) + + device_type = A.device.type + + if device_type == "cpu": + code = code.cpu() + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif device_type in ["cuda", "hip"]: + if not HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + + code = code.to(A.device) + + is_on_gpu([A, out, absmax]) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + else: + raise RuntimeError(f"Device type {device_type} not supported for quantization") + + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) + else: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + + return out, quant_state + + +def dequantize_blockwise( + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, +) -> torch.Tensor: + """Dequantize a tensor in blocks of values. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. + """ + + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if quant_state is None: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + absmax = quant_state.absmax + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is None: + out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) + + device_type = A.device.type + + if device_type == "cpu": + code = quant_state.code.cpu() + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) + elif device_type in ["cuda", "hip"]: + code = quant_state.code.to(A.device) + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}", + ) + + is_on_gpu([A, absmax, out]) + + with _cuda_device_of(A): + args = ( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif out.dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + else: + raise RuntimeError(f"Device type {device_type} not supported for dequantization") + + return out def get_4bit_type(typename, device=None, blocksize=64): @@ -953,10 +1023,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -964,10 +1036,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -975,7 +1049,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1003,6 +1077,9 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1053,8 +1130,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1063,8 +1142,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1073,7 +1154,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1102,6 +1183,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + if quant_state is None: assert absmax is not None and out is not None From 09249c897e47708ea9d4e594b8deaea439d74ade Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:12:20 +0530 Subject: [PATCH 004/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 106 +++++++++++++----------------- 1 file changed, 44 insertions(+), 62 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd63c888d..40f25a18f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,3 +1,4 @@ + from collections.abc import Sequence import ctypes as ct from math import prod @@ -5,7 +6,7 @@ import torch -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu from ..._ops import register_kernel from ...cextension import lib, HIP_ENVIRONMENT @@ -43,7 +44,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) torch._check( lda == ldb, @@ -53,10 +54,18 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. # We'll fall back to a slower fp32 calculation in this circumstance. # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + + is_on_gpu([A, B, out]) + with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) @@ -71,8 +80,11 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + if has_error: if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` @@ -111,6 +123,8 @@ def _( # Note: fused bias in the kernel is only supported for fp16 # TODO(matthewdouglas): Consider supporting bf16 fused bias ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + is_on_gpu([A, row_stats, col_stats, out, bias]) with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( @@ -128,6 +142,8 @@ def _( def _(A: torch.Tensor, threshold=0.0): torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + is_on_gpu([A]) rows = prod(A.shape[:-1]) cols = A.shape[-1] @@ -216,7 +232,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") @@ -225,8 +241,10 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + + is_on_gpu([A, out, absmax]) + + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( get_ptr(code), @@ -245,30 +263,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor lib.cquantize_blockwise_fp32(*args) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - elif device_type == 'cpu': - cpu_kernel_func = getattr(lib, 'cquantize_blockwise_cpu_fp32', None) - if cpu_kernel_func: - A_cpu = A.to(torch.float32) if A.dtype != torch.float32 else A - code_cpu = code.to('cpu') - absmax_cpu = torch.empty(absmax.shape, device='cpu', dtype=torch.float32) - out_cpu = torch.empty(out.shape, device='cpu', dtype=torch.uint8) - - cpu_kernel_func( - get_ptr(code_cpu), - get_ptr(A_cpu), - get_ptr(absmax_cpu), - get_ptr(out_cpu), - ct.c_longlong(blocksize), - ct.c_longlong(A_cpu.numel()) - ) - - out.copy_(out_cpu) - absmax.copy_(absmax_cpu) - else: - raise NotImplementedError("CPU blockwise quantization requires C extension support") - else: - raise NotImplementedError(f"Blockwise quantization not implemented for {device_type}") - + return out, absmax @@ -302,7 +297,7 @@ def _dequantize_blockwise_impl( if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") @@ -310,8 +305,10 @@ def _dequantize_blockwise_impl( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + + is_on_gpu([A, absmax, out]) + + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( get_ptr(code), @@ -328,29 +325,8 @@ def _dequantize_blockwise_impl( elif dtype == torch.bfloat16: lib.cdequantize_blockwise_bf16(*args) elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - elif device_type == 'cpu': - cpu_kernel_func = getattr(lib, 'cdequantize_blockwise_cpu_fp32', None) - if cpu_kernel_func: - code_cpu = code.to('cpu') - A_cpu = A.to('cpu') - absmax_cpu = absmax.to('cpu') - out_cpu = torch.empty(out.shape, dtype=torch.float32, device='cpu') - - cpu_kernel_func( - get_ptr(code_cpu), - get_ptr(A_cpu), - get_ptr(absmax_cpu), - get_ptr(out_cpu), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()) - ) - - out.copy_(out_cpu.to(dtype)) - else: - raise NotImplementedError("CPU blockwise dequantization requires C extension support") - else: - raise NotImplementedError(f"Blockwise dequantization not implemented for {device_type}") + lib.cdequantize_blockwise_fp32(*args) + @register_kernel("bitsandbytes::quantize_4bit", "cuda") def _( @@ -375,7 +351,9 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - + + is_on_gpu([A, out, absmax]) + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -450,7 +428,7 @@ def _dequantize_4bit_impl( if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) torch._check(quant_type in ["fp4", "nf4"]) @@ -459,6 +437,8 @@ def _dequantize_4bit_impl( lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) + is_on_gpu([A, absmax, out]) + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -550,6 +530,8 @@ def _gemv_4bit_impl( lda = m ldb = ct.c_int32((A.shape[-1] + 1) // 2) ldc = m + + is_on_gpu([B, A, out, absmax]) stream = _get_tensor_stream(A) From 4afa7741b3b7105ac6a42700dab1fd83b5050fc5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:12:36 +0530 Subject: [PATCH 005/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 40f25a18f..ce5401c5f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,4 +1,3 @@ - from collections.abc import Sequence import ctypes as ct from math import prod From 033d92cef2d41431fd4247c272c9429f7304bf40 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:23:34 +0530 Subject: [PATCH 006/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index ce5401c5f..14f55847c 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -353,7 +353,7 @@ def _( is_on_gpu([A, out, absmax]) - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( None, @@ -438,7 +438,7 @@ def _dequantize_4bit_impl( is_on_gpu([A, absmax, out]) - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( None, From 4def9590abb8a3f0ef789fce0b1659af729643e4 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 20:51:50 +0530 Subject: [PATCH 007/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14f55847c..5b94c5349 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -11,7 +11,6 @@ from ...cextension import lib, HIP_ENVIRONMENT - @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") def _(A: torch.Tensor, B: torch.Tensor): out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) @@ -78,12 +77,9 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - - if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + if has_error: if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` @@ -96,6 +92,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor return out + @register_kernel("bitsandbytes::int8_mm_dequant", "cuda") def _( A: torch.Tensor, @@ -384,6 +381,7 @@ def _( return out, absmax + @register_kernel("bitsandbytes::dequantize_4bit", "cuda") def _( A: torch.Tensor, @@ -398,7 +396,6 @@ def _( return out - @register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") def _( A: torch.Tensor, @@ -496,7 +493,6 @@ def _( torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - def _gemv_4bit_impl( A: torch.Tensor, B: torch.Tensor, From 0f318667aaf4de15cd29f8063dcaa4fd90d24783 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 21:31:55 +0530 Subject: [PATCH 008/102] Update functional.py --- bitsandbytes/functional.py | 157 ++++++++++++------------------------- 1 file changed, 48 insertions(+), 109 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7730f7182..3f0c1ff94 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -728,11 +728,9 @@ def quantize_blockwise( nested=False, ) -> tuple[torch.Tensor, QuantState]: """Quantize a tensor in blocks of values. - The input tensor is quantized by dividing it into blocks of `blocksize` values. The the absolute maximum value within these blocks is calculated for scaling the non-linear quantization. - Args: A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. code (`torch.Tensor`, *optional*): @@ -744,10 +742,8 @@ def quantize_blockwise( The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - Raises: ValueError: Raised when the input data type is not supported. - Returns: `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - `torch.Tensor`: The quantized tensor. @@ -759,61 +755,23 @@ def quantize_blockwise( name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - if absmax is None: - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - device_type = A.device.type - - if device_type == "cpu": - code = code.cpu() - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) - elif device_type in ["cuda", "hip"]: + if device_type in ["cuda", "hip"]: if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - - code = code.to(A.device) - - is_on_gpu([A, out, absmax]) + assert blocksize in [4096, 2048, 1024, 512, 256, 128] - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - else: - raise RuntimeError(f"Device type {device_type} not supported for quantization") + _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( + A, + code.to(A.device), + blocksize, + ) if nested: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + offset = _absmax.mean() + _absmax -= offset + qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) quant_state = QuantState( absmax=qabsmax, code=code, @@ -823,11 +781,18 @@ def quantize_blockwise( state2=state2, ) else: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) + + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + quant_state.absmax = absmax.copy_(quant_state.absmax) + + return out, quant_state + - return out, quant_state - - def dequantize_blockwise( A: torch.Tensor, quant_state: Optional[QuantState] = None, @@ -838,11 +803,9 @@ def dequantize_blockwise( nested=False, ) -> torch.Tensor: """Dequantize a tensor in blocks of values. - The input tensor is dequantized by dividing it into blocks of `blocksize` values. The the absolute maximum value within these blocks is used for scaling the non-linear dequantization. - Args: A (`torch.Tensor`): The quantized input tensor. quant_state ([`QuantState`], *optional*): @@ -860,10 +823,8 @@ def dequantize_blockwise( The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. Ignored when `quant_state` is provided. - Raises: ValueError: Raised when the input data type is not supported. - Returns: `torch.Tensor`: The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. @@ -878,6 +839,16 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + device_type = A.device.type + if device_type in ["cuda", "hip"]: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" + ) + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -885,56 +856,24 @@ def dequantize_blockwise( if absmax.dtype != torch.float32: absmax = absmax.float() - if out is None: - out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - - device_type = A.device.type - - if device_type == "cpu": - code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(quant_state.absmax), - get_ptr(out), - ct.c_longlong(quant_state.blocksize), - ct.c_longlong(A.numel()), + if out is not None: + torch.ops.bitsandbytes.dequantize_blockwise.out( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + out=out, ) - elif device_type in ["cuda", "hip"]: - code = quant_state.code.to(A.device) - supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}", - ) - - is_on_gpu([A, absmax, out]) - - with _cuda_device_of(A): - args = ( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - else: - raise RuntimeError(f"Device type {device_type} not supported for dequantization") + return out - return out + return torch.ops.bitsandbytes.dequantize_blockwise.default( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + ) def get_4bit_type(typename, device=None, blocksize=64): From 190faed7e96b8b27e033fe3c6ee5e3a6d5a4772a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 23:35:15 +0530 Subject: [PATCH 009/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 5b94c5349..ff5e023cc 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -52,15 +52,9 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. # We'll fall back to a slower fp32 calculation in this circumstance. # Fortunately, this should not be very common. - - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - if out is not None: - result = out.copy_(result) - return result - - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) is_on_gpu([A, B, out]) From d7f413b9b367b9b26b87180095ebcc7a561fdc26 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 23:52:39 +0530 Subject: [PATCH 010/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index ff5e023cc..b75f67d62 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -55,8 +55,6 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor if lda % 4 != 0: result = torch.matmul(B.float(), A.float().t()).to(torch.int32) return out.copy_(result) - - is_on_gpu([A, B, out]) with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) @@ -114,8 +112,6 @@ def _( # TODO(matthewdouglas): Consider supporting bf16 fused bias ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - is_on_gpu([A, row_stats, col_stats, out, bias]) - with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) @@ -133,8 +129,6 @@ def _(A: torch.Tensor, threshold=0.0): torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - is_on_gpu([A]) - rows = prod(A.shape[:-1]) cols = A.shape[-1] @@ -231,9 +225,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - is_on_gpu([A, out, absmax]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -295,9 +287,7 @@ def _dequantize_blockwise_impl( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - is_on_gpu([A, absmax, out]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -341,9 +331,7 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - is_on_gpu([A, out, absmax]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -370,8 +358,6 @@ def _( lib.cquantize_blockwise_fp32_fp4(*args) else: lib.cquantize_blockwise_fp32_nf4(*args) - else: - raise NotImplementedError(f"4-bit quantization not implemented for {device_type}") return out, absmax @@ -400,11 +386,11 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -426,9 +412,7 @@ def _dequantize_4bit_impl( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) - - is_on_gpu([A, absmax, out]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -456,8 +440,6 @@ def _dequantize_4bit_impl( lib.cdequantize_blockwise_fp32_fp4(*args) else: lib.cdequantize_blockwise_fp32_nf4(*args) - else: - raise NotImplementedError(f"4-bit dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::gemv_4bit", "cuda") @@ -520,8 +502,6 @@ def _gemv_4bit_impl( ldb = ct.c_int32((A.shape[-1] + 1) // 2) ldc = m - is_on_gpu([B, A, out, absmax]) - stream = _get_tensor_stream(A) with _cuda_device_of(A): From 3b6e68a001b0dce3b368129599335fcb569ac5cd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 00:05:43 +0530 Subject: [PATCH 011/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index b75f67d62..156125c9f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel from ...cextension import lib, HIP_ENVIRONMENT From 06740b1372a9c9751216b76dc4c8cc98514905dd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 01:53:30 +0530 Subject: [PATCH 012/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 989 +++++++++++++++--------------- 1 file changed, 486 insertions(+), 503 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 156125c9f..48dc75135 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,325 +1,312 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - device = A.device - device_type = device.type +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( @@ -331,66 +318,65 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -399,157 +385,154 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - None, + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, get_ptr(A), + get_ptr(B), get_ptr(absmax), + get_ptr(code), get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) From 9fe67efada457a759d1d8193265243209e784e2c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 02:11:31 +0530 Subject: [PATCH 013/102] Update functional.py --- bitsandbytes/functional.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3f0c1ff94..237aa3e54 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -754,13 +754,11 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - - device_type = A.device.type - if device_type in ["cuda", "hip"]: - if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] + + if HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -839,15 +837,15 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - device_type = A.device.type - if device_type in ["cuda", "hip"]: + if HIP_ENVIRONMENT: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] + else: supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" - ) + + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" + ) absmax = quant_state.absmax if quant_state.nested: From d97fdce654129ca156f0cb47555529d4f4941778 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 02:18:37 +0530 Subject: [PATCH 014/102] Update functional.py --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 237aa3e54..1cee234ea 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -858,8 +858,8 @@ def dequantize_blockwise( torch.ops.bitsandbytes.dequantize_blockwise.out( A, absmax, - quant_state.code.to(A.device), - quant_state.blocksize, + code.to(A.device), + blocksize, quant_state.dtype, out=out, ) From f1fbe92d2bc2eebc4629ee41a76b163772cd1874 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 24 May 2025 21:53:44 +0530 Subject: [PATCH 015/102] Update functional.py --- bitsandbytes/functional.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 1cee234ea..b51258420 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -960,12 +960,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_storage=torch.uint8, ): - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -973,12 +973,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_storage=torch.uint8, ): - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -986,7 +986,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1014,8 +1014,8 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 input_shape = A.shape @@ -1067,10 +1067,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, ) -> torch.Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1079,10 +1079,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, ) -> torch.Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1091,7 +1091,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1121,8 +1121,8 @@ def dequantize_4bit( `torch.Tensor`: The dequantized tensor. """ - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 if quant_state is None: assert absmax is not None and out is not None From 660c25448edcff9f0f56368cc9ef04e91045d52c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 24 May 2025 21:57:22 +0530 Subject: [PATCH 016/102] Update functional.py --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b51258420..2ae977e7a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -986,7 +986,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, From c692f4bc8f604f50a8a4f4409d373ed70c630364 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 27 May 2025 21:45:04 +0530 Subject: [PATCH 017/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 48dc75135..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -3,7 +3,7 @@ from math import prod from typing import Optional -import torch +import torch from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr From 46f9800d9e9a361ecabf1051f99776fbfc73589d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 27 May 2025 21:55:36 +0530 Subject: [PATCH 018/102] Update ops.py From 7823bac2c0c234c468392c219b29ed51dea8ca96 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:12:42 +0530 Subject: [PATCH 019/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 ++++++++++++++--------------- 1 file changed, 521 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,538 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From d0ed1077d910acc4cd6f3ec4c57cf597931ff20c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:14:34 +0530 Subject: [PATCH 020/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 +++++++++++++++-------------- 1 file changed, 538 insertions(+), 521 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,521 +1,538 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From af3aaf6a5d5ee90d713fbba875ab3cbd5137c619 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:17:20 +0530 Subject: [PATCH 021/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..aa7c82f09 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -536,3 +536,4 @@ def _gemv_4bit_impl( ct.c_int32(blocksize), stream, ) + From d1e34a5dfe80aa95c42de7187800468d7a9e1b8a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:18:53 +0530 Subject: [PATCH 022/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1058 ++++++++++++++--------------- 1 file changed, 520 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index aa7c82f09..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,539 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From b2b4df6d3046a166d6e177de2dbca26f1b0abcab Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:21:15 +0530 Subject: [PATCH 023/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 +++++++++++++++-------------- 1 file changed, 538 insertions(+), 521 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,521 +1,538 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From 8863d0e3d55c73478926c9388080750be2e49690 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:22:01 +0530 Subject: [PATCH 024/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 ++++++++++++++--------------- 1 file changed, 521 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,538 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From d1a5e8dec4e212e5c722d884809d5645c4772a1b Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:35:33 +0530 Subject: [PATCH 025/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..fd7b7b9a2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib +from ...cextension import lib, HIP_ENVIRONMENT @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,7 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -264,7 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], @@ -294,7 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -372,7 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 843ea338f968e06d586ac70c68e70b3a2c56c228 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:39:54 +0530 Subject: [PATCH 026/102] Update functional.py --- bitsandbytes/functional.py | 316 +++++++++++++++++-------------------- 1 file changed, 147 insertions(+), 169 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2ae977e7a..b0092ffd1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib, HIP_ENVIRONMENT +from .cextension import lib name2qmap = {} @@ -719,159 +719,152 @@ def __eq__(self, other): ) -def quantize_blockwise( - A: torch.Tensor, - code: Optional[torch.Tensor] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=4096, - nested=False, -) -> tuple[torch.Tensor, QuantState]: - """Quantize a tensor in blocks of values. - The input tensor is quantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is calculated for scaling - the non-linear quantization. - Args: - A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - Raises: - ValueError: Raised when the input data type is not supported. - Returns: - `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - - `torch.Tensor`: The quantized tensor. - - [`QuantState`]: The state object used to undo the quantization. - """ - - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( - A, - code.to(A.device), - blocksize, - ) - - if nested: - offset = _absmax.mean() - _absmax -= offset - qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) - else: - quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) - - # TODO(matthewdouglas): Deprecate out kwarg - out = out.copy_(_out) if out is not None else _out - - # TODO(matthewdouglas): Deprecate absmax kwarg - if absmax is not None: - quant_state.absmax = absmax.copy_(quant_state.absmax) - - return out, quant_state - - -def dequantize_blockwise( - A: torch.Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 4096, - nested=False, -) -> torch.Tensor: - """Dequantize a tensor in blocks of values. - The input tensor is dequantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is used for scaling - the non-linear dequantization. - Args: - A (`torch.Tensor`): The quantized input tensor. - quant_state ([`QuantState`], *optional*): - The quantization state as returned by [`quantize_blockwise`]. - Required if `absmax` is not provided. - absmax (`torch.Tensor`, *optional*): - A tensor containing the scaling values. - Required if `quant_state` is not provided and ignored otherwise. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - Ignored when `quant_state` is provided. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - Ignored when `quant_state` is provided. - Raises: - ValueError: Raised when the input data type is not supported. - Returns: - `torch.Tensor`: - The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. - """ - - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - if HIP_ENVIRONMENT: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] - else: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" - ) - - absmax = quant_state.absmax - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is not None: - torch.ops.bitsandbytes.dequantize_blockwise.out( - A, - absmax, - code.to(A.device), - blocksize, - quant_state.dtype, - out=out, - ) - return out - - return torch.ops.bitsandbytes.dequantize_blockwise.default( - A, - absmax, - quant_state.code.to(A.device), - quant_state.blocksize, - quant_state.dtype, - ) +def quantize_blockwise( + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> tuple[torch.Tensor, QuantState]: + """Quantize a tensor in blocks of values. + + The input tensor is quantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. + """ + + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( + A, + code.to(A.device), + blocksize, + ) + + if nested: + offset = _absmax.mean() + _absmax -= offset + qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) + else: + quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) + + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + quant_state.absmax = absmax.copy_(quant_state.absmax) + + return out, quant_state + + +def dequantize_blockwise( + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, +) -> torch.Tensor: + """Dequantize a tensor in blocks of values. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. + """ + + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if quant_state is None: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + absmax = quant_state.absmax + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is not None: + torch.ops.bitsandbytes.dequantize_blockwise.out( + A, + absmax, + code.to(A.device), + blocksize, + quant_state.dtype, + out=out, + ) + return out + + return torch.ops.bitsandbytes.dequantize_blockwise.default( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + ) def get_4bit_type(typename, device=None, blocksize=64): @@ -964,8 +957,6 @@ def quantize_fp4( compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -977,8 +968,6 @@ def quantize_nf4( compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -1014,9 +1003,6 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - if HIP_ENVIRONMENT: - blocksize = 128 - input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1069,8 +1055,6 @@ def dequantize_fp4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1081,8 +1065,6 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1120,10 +1102,6 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - - if HIP_ENVIRONMENT: - blocksize = 128 - if quant_state is None: assert absmax is not None and out is not None From d6d2e5f32ffd30070c45f89704b8db20f600b577 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:57:37 +0530 Subject: [PATCH 027/102] Update functional.py --- bitsandbytes/functional.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b0092ffd1..959eeb33a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import lib, HIP_ENVIRONMENT name2qmap = {} @@ -758,6 +758,11 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] + + if HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -839,6 +844,16 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + if HIP_ENVIRONMENT: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] + else: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] + + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" + ) absmax = quant_state.absmax if quant_state.nested: @@ -957,6 +972,8 @@ def quantize_fp4( compress_statistics=False, quant_storage=torch.uint8, ): + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -968,6 +985,8 @@ def quantize_nf4( compress_statistics=False, quant_storage=torch.uint8, ): + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -1003,6 +1022,9 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + if HIP_ENVIRONMENT: + blocksize = 128 + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1055,6 +1077,8 @@ def dequantize_fp4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1065,6 +1089,8 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1102,6 +1128,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ + + if HIP_ENVIRONMENT: + blocksize = 128 + if quant_state is None: assert absmax is not None and out is not None From e3f9f21236ac76cac026eacf1da26f15e7a0ad1f Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 13:23:18 +0530 Subject: [PATCH 028/102] Update functional.py --- bitsandbytes/functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 959eeb33a..f4be0dc2f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1022,6 +1022,7 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + if HIP_ENVIRONMENT: blocksize = 128 From bc0957daa57fc1364f914c2928bcfb730f97dc9d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 17:26:33 +0530 Subject: [PATCH 029/102] Update test_ops.py --- tests/test_ops.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 4da1663f0..bb49c7dbb 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,6 +5,7 @@ import bitsandbytes from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter +from bitsandbytes.cextension import HIP_ENVIRONMENT class TestLLMInt8Ops: @@ -95,7 +96,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -119,7 +120,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -145,7 +146,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -169,7 +170,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": if quant_type != "nf4": @@ -206,7 +207,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": pytest.xfail("CPU implementation is not available") From b8247ab109de936bcefb932b7d0ed996168f8445 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 17:34:22 +0530 Subject: [PATCH 030/102] Update test_functional.py --- tests/test_functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 96e77e4f4..3b9b53a24 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,6 +8,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes import functional as F from tests.helpers import ( BOOLEAN_TUPLES, @@ -91,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) + @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128] if HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128, 64] ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -147,7 +148,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096, 16384]) + @pytest.mark.parametrize("blocksize", [4096] if HIP_ENVIRONMENT else [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] @@ -1105,7 +1106,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("blocksize", [128, 256, 512, 1024, 2048, 4096] if HIP_ENVIRONMENT else [64, 128, 256, 512, 1024, 2048, 4096]) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 531758a10835e68a10002eb825383a1a0608cb65 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 20:19:07 +0530 Subject: [PATCH 031/102] Update test_ops.py --- tests/test_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index bb49c7dbb..a99d080b3 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -96,7 +96,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -120,7 +120,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -146,7 +146,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -170,7 +170,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": if quant_type != "nf4": @@ -207,7 +207,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": pytest.xfail("CPU implementation is not available") From 6d7db8efa3a2d249434378ab09f3e9f5c0d72c26 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 20:29:23 +0530 Subject: [PATCH 032/102] Update test_functional.py --- tests/test_functional.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 3b9b53a24..4b62c2567 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -92,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128] if HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128, 64] ) + @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -148,7 +148,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096] if HIP_ENVIRONMENT else [4096, 16384]) + @pytest.mark.parametrize("blocksize", [4096, 16384] if not HIP_ENVIRONMENT else [4096]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] @@ -1106,7 +1106,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512, 1024, 2048, 4096] if HIP_ENVIRONMENT else [64, 128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1205,7 +1205,10 @@ def test_bench_4bit_dequant(self, quant_type): # torch.matmul(b, a.t()) # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - + + @pytest.mark.skipif( + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" + ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1369,6 +1372,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) + @pytest.mark.skipif( + HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet", + ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 632e95b92d9feba37401ede69ad119017b50ae9d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 21:05:21 +0530 Subject: [PATCH 033/102] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4b62c2567..7ad604d9f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1373,7 +1373,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet", + HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet" ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": From 90d9af2c387f05bcf4dc8d409a0ac3e4ef0d8e95 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 22:04:55 +0530 Subject: [PATCH 034/102] Update functional.py --- bitsandbytes/functional.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f4be0dc2f..2405a1985 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -968,12 +968,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -981,12 +981,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -994,7 +994,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1023,8 +1023,8 @@ def quantize_4bit( - [`QuantState`]: The state object used to undo the quantization. """ - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 input_shape = A.shape @@ -1076,10 +1076,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1088,10 +1088,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1100,7 +1100,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1130,8 +1130,8 @@ def dequantize_4bit( `torch.Tensor`: The dequantized tensor. """ - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 if quant_state is None: assert absmax is not None and out is not None From 80048d89f249509db4c1fb482ce7694fcca3fdcb Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 01:38:52 +0530 Subject: [PATCH 035/102] Update functional.py --- bitsandbytes/functional.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2405a1985..03f6c323d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -758,11 +758,6 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - - if HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -844,16 +839,6 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - if HIP_ENVIRONMENT: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] - else: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" - ) absmax = quant_state.absmax if quant_state.nested: From e448ebbadf0313f429005001791c56d092992f01 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:40:56 +0530 Subject: [PATCH 036/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd7b7b9a2..f03d06599 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -303,11 +303,7 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -385,11 +381,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 048faa8ce60088fedc05474157c6356b14c3ee80 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:41:52 +0530 Subject: [PATCH 037/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index f03d06599..29dddc96e 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -381,7 +381,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From c45e9d18c9fa55135cdaea92b68a4e8660d80bf6 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:44:51 +0530 Subject: [PATCH 038/102] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 7ad604d9f..07c0d4964 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -148,7 +148,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096, 16384] if not HIP_ENVIRONMENT else [4096]) + @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] From 47a491fb213b5286e0ed3cc9af773bf02f416f24 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:36:25 +0530 Subject: [PATCH 039/102] Update test_functional.py --- tests/test_functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 07c0d4964..2219efa2f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,7 +8,7 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from bitsandbytes import functional as F from tests.helpers import ( BOOLEAN_TUPLES, @@ -1373,7 +1373,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet" + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": From 86976bc22b04bc1415a13648582e453ce594700c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:38:53 +0530 Subject: [PATCH 040/102] Update cextension.py --- bitsandbytes/cextension.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c8b02fb22..108aa0c9a 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -8,7 +8,7 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch logger = logging.getLogger(__name__) @@ -298,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary: return BNBNativeLibrary(dll) +ROCM_GPU_ARCH = get_rocm_gpu_arch() + try: if torch.version.hip: HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" From 98a142a7c7961fc58c0b90b388f080d56991b94c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:41:51 +0530 Subject: [PATCH 041/102] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 64903cd49..da34dd608 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,9 @@ import dataclasses +import logging +import re +import subprocess from functools import lru_cache -from typing import Optional +from typing import Optional, List, Tuple import torch @@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]: ) except Exception: return None + + +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" From 888fe46fee6fe59f377e4c4a3f19468a06094b91 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:59:01 +0530 Subject: [PATCH 042/102] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index da34dd608..61d03083c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -3,7 +3,7 @@ import re import subprocess from functools import lru_cache -from typing import Optional, List, Tuple +from typing import Optional import torch From c9c52b56c1145d9ecd6ccfc4833799eae3bb2ccd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 15:59:13 +0530 Subject: [PATCH 043/102] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 2219efa2f..41ed7c984 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1141,7 +1141,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From fc29586e8951cbe41aa5693ba0cd3ae3d25b05db Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:23:38 +0530 Subject: [PATCH 044/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 67b61cb05..474a00a1b 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -7,6 +7,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer storage = { @@ -16,7 +17,7 @@ "float32": torch.float32, } - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) From 53b8b1c580093e39d43d0018fa47abee6966442c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:27:39 +0530 Subject: [PATCH 045/102] Update test_cuda_setup_evaluator.py --- tests/test_cuda_setup_evaluator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 79406472e..1b2ea85db 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,6 +1,6 @@ import pytest -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @@ -12,12 +12,12 @@ def cuda120_spec() -> CUDASpecs: cuda_version_tuple=(12, 0), ) - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" From fe1fe7ccd0ab1c2a41da85d865e467de691cefac Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:34:11 +0530 Subject: [PATCH 046/102] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 41ed7c984..5f5ee488c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -796,7 +796,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) From e198824c5c9e23bb15d6eb2aa07a04f09e95446f Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:36:53 +0530 Subject: [PATCH 047/102] Update modules.py --- bitsandbytes/nn/modules.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 937084cf1..6b6490265 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -212,7 +212,7 @@ def __new__( data: Optional[torch.Tensor] = None, requires_grad=False, # quantized weights should be frozen by default quant_state: Optional[QuantState] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, compress_statistics: bool = True, quant_type: str = "fp4", quant_storage: torch.dtype = torch.uint8, @@ -221,7 +221,10 @@ def __new__( ) -> "Params4bit": if data is None: data = torch.empty(0) - + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics From dd58310df17b69c63a9a06186e7f6bb24c98a199 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:37:28 +0530 Subject: [PATCH 048/102] Update modules.py --- bitsandbytes/nn/modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6b6490265..2383f2c10 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( From 931bd70d868df8a663d32c3d4b410f72a45c1c3b Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:50:14 +0530 Subject: [PATCH 049/102] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 29dddc96e..fd7b7b9a2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -303,7 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -381,7 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 9e62d466d226a62bd61e73afd676a694e1d13eac Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 18:56:05 +0530 Subject: [PATCH 050/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 474a00a1b..c241a265d 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -184,7 +184,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "cpu": @@ -209,7 +209,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "cpu": @@ -241,7 +241,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "cpu": From 1f71562a9ba57dd209f844549ffc8ff98bebb06d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 19:05:12 +0530 Subject: [PATCH 051/102] Update ops.py --- bitsandbytes/backends/cpu/ops.py | 93 ++++++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d5ab9aa88..f58be5d2a 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -103,16 +103,39 @@ def _( n = A.numel() - # TODO: Support when weight matrix is not divisible by blocksize - torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") - - # Divide into blocks and normalize - blocks = A.reshape(-1, blocksize) - absmax = blocks.abs().max(dim=1).values.float() - scaled = blocks / absmax.unsqueeze(-1) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + blocks = n // blocksize + rem = n % blocksize + has_rem = rem > 0 + if has_rem: + blocks += 1 + + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + + if n >= blocksize: + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=1).values.float() + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].unsqueeze(-1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max().float() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + else: + blocks = A.reshape(-1, blocksize) + absmax = blocks.abs().max(dim=1).values.float() + scaled_A = blocks / absmax.unsqueeze(-1) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + + if quantized.numel() % 2 == 1: + quantized = torch.cat([quantized, torch.zeros((1, 1), device=A.device, dtype=torch.uint8)]) # Pack two quantized values per byte packed = quantized[::2] << 4 | quantized[1::2] @@ -149,16 +172,52 @@ def _( upper = (A >> 4).to(torch.int64) lower = (A & 0x0F).to(torch.int64) - # Expand to blocks - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + # Calculate the total number of elements in the original tensor + n = 1 + for d in shape: + n *= d + + # Concatenate upper and lower nibbles + indices = torch.cat((upper, lower), dim=1).reshape(-1) + + if indices.numel() > n: + indices = indices[:n] + + blocks = n // blocksize + rem = n % blocksize + has_rem = rem > 0 + if has_rem: + blocks += 1 + + if has_rem: + out = torch.empty(shape, dtype=dtype, device=A.device) + out_reshaped = out.reshape(-1) + + padded_indices = torch.zeros(blocks * blocksize, dtype=indices.dtype, device=indices.device) + padded_indices[:n] = indices + blocks_data = padded_indices.reshape(-1, blocksize) + + # Dequantize full blocks + dequantized = _NF4_QUANT_TABLE[blocks_data] + + # Apply scales to full blocks + out_reshaped[:n - rem] = ( + dequantized[:blocks - 1].reshape(-1, blocksize) * absmax[:blocks - 1].view(-1, 1) + ).reshape(-1) + + # Apply scale to remainder block + out_reshaped[n - rem:] = dequantized[blocks - 1, :rem] * absmax[-1] + else: + # Expand to blocks + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - # Dequantize - blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] + # Dequantize + blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] - # Reshape to original shape - blocks = blocks.reshape(-1, *shape[1:]) + # Reshape to original shape + out = blocks.reshape(-1, *shape[1:]) - return blocks.to(dtype) + return out.to(dtype) @register_kernel("bitsandbytes::gemv_4bit", "cpu") From eac7632e28043caad307cf2b5e1ff61fc9cbfe12 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:46:28 +0530 Subject: [PATCH 052/102] Update ops.py --- bitsandbytes/backends/cpu/ops.py | 93 ++++++-------------------------- 1 file changed, 17 insertions(+), 76 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index f58be5d2a..d5ab9aa88 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -103,39 +103,16 @@ def _( n = A.numel() - blocks = n // blocksize - rem = n % blocksize - has_rem = rem > 0 - if has_rem: - blocks += 1 - - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - - if n >= blocksize: - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=1).values.float() - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].unsqueeze(-1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max().float() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) - else: - blocks = A.reshape(-1, blocksize) - absmax = blocks.abs().max(dim=1).values.float() - scaled_A = blocks / absmax.unsqueeze(-1) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) - - if quantized.numel() % 2 == 1: - quantized = torch.cat([quantized, torch.zeros((1, 1), device=A.device, dtype=torch.uint8)]) + # TODO: Support when weight matrix is not divisible by blocksize + torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") + + # Divide into blocks and normalize + blocks = A.reshape(-1, blocksize) + absmax = blocks.abs().max(dim=1).values.float() + scaled = blocks / absmax.unsqueeze(-1) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) # Pack two quantized values per byte packed = quantized[::2] << 4 | quantized[1::2] @@ -172,52 +149,16 @@ def _( upper = (A >> 4).to(torch.int64) lower = (A & 0x0F).to(torch.int64) - # Calculate the total number of elements in the original tensor - n = 1 - for d in shape: - n *= d - - # Concatenate upper and lower nibbles - indices = torch.cat((upper, lower), dim=1).reshape(-1) - - if indices.numel() > n: - indices = indices[:n] - - blocks = n // blocksize - rem = n % blocksize - has_rem = rem > 0 - if has_rem: - blocks += 1 - - if has_rem: - out = torch.empty(shape, dtype=dtype, device=A.device) - out_reshaped = out.reshape(-1) - - padded_indices = torch.zeros(blocks * blocksize, dtype=indices.dtype, device=indices.device) - padded_indices[:n] = indices - blocks_data = padded_indices.reshape(-1, blocksize) - - # Dequantize full blocks - dequantized = _NF4_QUANT_TABLE[blocks_data] - - # Apply scales to full blocks - out_reshaped[:n - rem] = ( - dequantized[:blocks - 1].reshape(-1, blocksize) * absmax[:blocks - 1].view(-1, 1) - ).reshape(-1) - - # Apply scale to remainder block - out_reshaped[n - rem:] = dequantized[blocks - 1, :rem] * absmax[-1] - else: - # Expand to blocks - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + # Expand to blocks + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - # Dequantize - blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] + # Dequantize + blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] - # Reshape to original shape - out = blocks.reshape(-1, *shape[1:]) + # Reshape to original shape + blocks = blocks.reshape(-1, *shape[1:]) - return out.to(dtype) + return blocks.to(dtype) @register_kernel("bitsandbytes::gemv_4bit", "cpu") From 66dcfc407f59052fa9d5359cdebf619886100033 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:16:02 +0530 Subject: [PATCH 053/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index c241a265d..1b7a7722c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -17,7 +17,6 @@ "float32": torch.float32, } -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) From b96905d26c63355884e7decc65297591e108679d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:17:02 +0530 Subject: [PATCH 054/102] Update test_linear4bit.py From ef31c362e22b201551605bc6d808026ea33da59c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 23:55:14 +0530 Subject: [PATCH 055/102] Update python-package.yml --- .github/workflows/python-package.yml | 643 ++++++++++++++------------- 1 file changed, 343 insertions(+), 300 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbaa27d56..10daf0f79 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,303 +1,346 @@ -name: Python package - -on: - push: {} - pull_request: - branches: [main] - paths: - - ".github/workflows/python-package.yml" - - "bitsandbytes/**" - - "csrc/**" - - "include/**" - - "tests/**" - - "CMakeLists.txt" - - "requirements*.txt" - - "setup.py" - - "pyproject.toml" - release: - types: [published] - workflow_dispatch: {} # Allow manual trigger - workflow_call: {} # Allow triggering from other worfkflows - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - ## - # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. - ## - build-shared-libs: - strategy: - matrix: - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cpu.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_${{ matrix.os }}_${{ matrix.arch }} - path: output/* - retention-days: 7 - ## - # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) - ## - build-shared-libs-cuda: - strategy: - fail-fast: false - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - cuda_version: - ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - # Windows: We install Cuda on the agent (slow) - - uses: Jimver/cuda-toolkit@v0.2.22 - if: startsWith(matrix.os, 'windows') - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda_version }} - method: "network" - sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' - linux-local-args: '["--toolkit"]' - use-github-cache: false - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cuda.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - cuda_version: ${{ matrix.cuda_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} - path: output/* - retention-days: 7 - - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - # The specific Python version is irrelevant in this context as we are only packaging non-C extension - # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is - # dictated by the packaged code itself, not the Python version used for packaging. - python-version: ["3.10"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - merge-multiple: true - pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" - path: output/ - - name: Copy correct platform shared library - shell: bash - run: | - ls -lR output/ - cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - - run: pip install build wheel - - run: python -m build . - - name: Determine and Set Platform Tag, then Tag Wheel - shell: bash - run: | - PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") - echo "PLATFORM_TAG=$PLATFORM_TAG" - wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: dist/bitsandbytes-*.whl - retention-days: 7 - - upload-pre-release-wheels: - name: Create release and upload artifacts - runs-on: ubuntu-latest - if: github.ref_name == 'main' - permissions: - contents: write - needs: - - build-wheels - steps: - - name: Download and rename artifacts - uses: actions/download-artifact@v4 - with: - path: tmp/ - pattern: "bdist_wheel_*" - merge-multiple: true +name: Python package - - name: Inspect tmp directory after downloading artifacts - run: ls -alFR tmp/ +on: + push: {} + pull_request: + branches: [main] + paths: + - ".github/workflows/python-package.yml" + - "bitsandbytes/**" + - "csrc/**" + - "include/**" + - "tests/**" + - "CMakeLists.txt" + - "requirements*.txt" + - "setup.py" + - "pyproject.toml" + release: + types: [published] + workflow_dispatch: {} # Allow manual trigger + workflow_call: {} # Allow triggering from other worfkflows - - name: Move and rename wheel files with pattern replacement - run: | - mkdir -p wheels/ - - # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name - # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the - # wheel directly from the GH pre-release which gets updated continuously, e.g. - # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` - STABLE_PLACEHOLDER_VERSION="1.33.7.preview" - - # exclude macos wheels for now - find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do - wheel_filename=$(basename "$wheel") - - # Strip off the original version - rest=${wheel_filename#bitsandbytes-*-} - new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" - - echo "Renaming $wheel_filename → $new_name" - mv "$wheel" "wheels/${new_name}" - done - - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true - - name: Delete old pre-release (if exists) - run: | - gh release delete continuous-release_main --cleanup-tag -y || true - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Generate pip install commands for release body - run: | - cat > body.md << 'ENDOFMARKDOWN' - ## Latest `main` Wheel Pre-release - - This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - - **How to install:** - Pick the correct command for your platform and run it in your terminal: - - ENDOFMARKDOWN - - for whl in wheels/*.whl; do - fname=$(basename "$whl") - url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" - echo "\`\`\`sh" >> body.md - echo "pip install $url" >> body.md - echo "\`\`\`" >> body.md - echo "" >> body.md - done - - cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** - > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. - ENDOFMARKDOWN - - # for debugging: - cat body.md - - - name: Create new pre-release and upload artifacts - uses: softprops/action-gh-release@v2.2.1 - with: - files: wheels/*.whl - prerelease: true - name: Latest `main` wheel - body_path: body.md - tag_name: continuous-release_main - make_latest: false - draft: false - target_commitish: ${{ github.sha }} - - audit-wheels: - needs: build-wheels - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - runs-on: ${{ matrix.os }} - env: - PIP_DISABLE_PIP_VERSION_CHECK: 1 - steps: - - uses: actions/checkout@v4 - - name: Download wheel - uses: actions/download-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: wheels/ - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - run: pip install auditwheel - - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY - - publish-wheels: - name: Publish wheels to PyPI - needs: [build-wheels, audit-wheels] - runs-on: ubuntu-latest - if: | - github.repository == 'bitsandbytes-foundation/bitsandbytes' - && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - environment: - name: release - url: https://pypi.org/p/bitsandbytes - permissions: - id-token: write - steps: - - name: Download distribution artifacts - uses: actions/download-artifact@v4 - with: - path: dist/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Remove macOS wheels - run: rm dist/*macos* - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - print-hash: true +jobs: + ## + # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. + ## + build-shared-libs: + strategy: + matrix: + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cpu.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_${{ matrix.os }}_${{ matrix.arch }} + path: output/* + retention-days: 7 + ## + # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) + ## + build-shared-libs-cuda: + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + cuda_version: + ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + # Windows: We install Cuda on the agent (slow) + - uses: Jimver/cuda-toolkit@v0.2.22 + if: startsWith(matrix.os, 'windows') + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda_version }} + method: "network" + sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + linux-local-args: '["--toolkit"]' + use-github-cache: false + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cuda.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + cuda_version: ${{ matrix.cuda_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} + path: output/* + retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + # The specific Python version is irrelevant in this context as we are only packaging non-C extension + # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is + # dictated by the packaged code itself, not the Python version used for packaging. + python-version: ["3.10"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" + path: output/ + - name: Copy correct platform shared library + shell: bash + run: | + ls -lR output/ + cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install build wheel + - run: python -m build . + - name: Determine and Set Platform Tag, then Tag Wheel + shell: bash + run: | + PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") + echo "PLATFORM_TAG=$PLATFORM_TAG" + wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: dist/bitsandbytes-*.whl + retention-days: 7 + + upload-pre-release-wheels: + name: Create release and upload artifacts + runs-on: ubuntu-latest + if: github.ref_name == 'main' + permissions: + contents: write + needs: + - build-wheels + steps: + - name: Download and rename artifacts + uses: actions/download-artifact@v4 + with: + path: tmp/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Inspect tmp directory after downloading artifacts + run: ls -alFR tmp/ + + - name: Move and rename wheel files with pattern replacement + run: | + mkdir -p wheels/ + + # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name + # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the + # wheel directly from the GH pre-release which gets updated continuously, e.g. + # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` + STABLE_PLACEHOLDER_VERSION="1.33.7.preview" + + # exclude macos wheels for now + find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do + wheel_filename=$(basename "$wheel") + + # Strip off the original version + rest=${wheel_filename#bitsandbytes-*-} + new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" + + echo "Renaming $wheel_filename → $new_name" + mv "$wheel" "wheels/${new_name}" + done + + - name: Inspect wheels directory after renaming files + run: ls -alFR wheels/ + + - name: Delete old pre-release (if exists) + run: | + gh release delete continuous-release_main --cleanup-tag -y || true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate pip install commands for release body + run: | + cat > body.md << 'ENDOFMARKDOWN' + ## Latest `main` Wheel Pre-release + + This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. + + **How to install:** + Pick the correct command for your platform and run it in your terminal: + + ENDOFMARKDOWN + + for whl in wheels/*.whl; do + fname=$(basename "$whl") + url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" + echo "\`\`\`sh" >> body.md + echo "pip install $url" >> body.md + echo "\`\`\`" >> body.md + echo "" >> body.md + done + + cat >> body.md << 'ENDOFMARKDOWN' + > **Note:** + > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. + ENDOFMARKDOWN + + # for debugging: + cat body.md + + - name: Create new pre-release and upload artifacts + uses: softprops/action-gh-release@v2.2.1 + with: + files: wheels/*.whl + prerelease: true + name: Latest `main` wheel + body_path: body.md + tag_name: continuous-release_main + make_latest: false + draft: false + target_commitish: ${{ github.sha }} + + audit-wheels: + needs: build-wheels + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + runs-on: ${{ matrix.os }} + env: + PIP_DISABLE_PIP_VERSION_CHECK: 1 + steps: + - uses: actions/checkout@v4 + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: wheels/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install auditwheel + - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY + + publish-wheels: + name: Publish wheels to PyPI + needs: [build-wheels, audit-wheels] + runs-on: ubuntu-latest + if: | + github.repository == 'bitsandbytes-foundation/bitsandbytes' + && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + environment: + name: release + url: https://pypi.org/p/bitsandbytes + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + path: dist/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Remove macOS wheels + run: rm dist/*macos* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + print-hash: true From e1435f01776137c3a253228b4234a23535532161 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 23:57:25 +0530 Subject: [PATCH 056/102] Update python-package.yml --- .github/workflows/python-package.yml | 643 +++++++++++++-------------- 1 file changed, 300 insertions(+), 343 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 10daf0f79..fbaa27d56 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,346 +1,303 @@ -name: Python package +name: Python package + +on: + push: {} + pull_request: + branches: [main] + paths: + - ".github/workflows/python-package.yml" + - "bitsandbytes/**" + - "csrc/**" + - "include/**" + - "tests/**" + - "CMakeLists.txt" + - "requirements*.txt" + - "setup.py" + - "pyproject.toml" + release: + types: [published] + workflow_dispatch: {} # Allow manual trigger + workflow_call: {} # Allow triggering from other worfkflows + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ## + # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. + ## + build-shared-libs: + strategy: + matrix: + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cpu.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_${{ matrix.os }}_${{ matrix.arch }} + path: output/* + retention-days: 7 + ## + # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) + ## + build-shared-libs-cuda: + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + cuda_version: + ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + # Windows: We install Cuda on the agent (slow) + - uses: Jimver/cuda-toolkit@v0.2.22 + if: startsWith(matrix.os, 'windows') + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda_version }} + method: "network" + sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + linux-local-args: '["--toolkit"]' + use-github-cache: false + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cuda.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + cuda_version: ${{ matrix.cuda_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} + path: output/* + retention-days: 7 + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + # The specific Python version is irrelevant in this context as we are only packaging non-C extension + # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is + # dictated by the packaged code itself, not the Python version used for packaging. + python-version: ["3.10"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" + path: output/ + - name: Copy correct platform shared library + shell: bash + run: | + ls -lR output/ + cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install build wheel + - run: python -m build . + - name: Determine and Set Platform Tag, then Tag Wheel + shell: bash + run: | + PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") + echo "PLATFORM_TAG=$PLATFORM_TAG" + wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: dist/bitsandbytes-*.whl + retention-days: 7 + + upload-pre-release-wheels: + name: Create release and upload artifacts + runs-on: ubuntu-latest + if: github.ref_name == 'main' + permissions: + contents: write + needs: + - build-wheels + steps: + - name: Download and rename artifacts + uses: actions/download-artifact@v4 + with: + path: tmp/ + pattern: "bdist_wheel_*" + merge-multiple: true -on: - push: {} - pull_request: - branches: [main] - paths: - - ".github/workflows/python-package.yml" - - "bitsandbytes/**" - - "csrc/**" - - "include/**" - - "tests/**" - - "CMakeLists.txt" - - "requirements*.txt" - - "setup.py" - - "pyproject.toml" - release: - types: [published] - workflow_dispatch: {} # Allow manual trigger - workflow_call: {} # Allow triggering from other worfkflows + - name: Inspect tmp directory after downloading artifacts + run: ls -alFR tmp/ -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true + - name: Move and rename wheel files with pattern replacement + run: | + mkdir -p wheels/ + + # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name + # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the + # wheel directly from the GH pre-release which gets updated continuously, e.g. + # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` + STABLE_PLACEHOLDER_VERSION="1.33.7.preview" + + # exclude macos wheels for now + find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do + wheel_filename=$(basename "$wheel") + + # Strip off the original version + rest=${wheel_filename#bitsandbytes-*-} + new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" + + echo "Renaming $wheel_filename → $new_name" + mv "$wheel" "wheels/${new_name}" + done + + - name: Inspect wheels directory after renaming files + run: ls -alFR wheels/ -jobs: - ## - # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. - ## - build-shared-libs: - strategy: - matrix: - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cpu.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_${{ matrix.os }}_${{ matrix.arch }} - path: output/* - retention-days: 7 - ## - # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) - ## - build-shared-libs-cuda: - strategy: - fail-fast: false - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - cuda_version: - ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - # Windows: We install Cuda on the agent (slow) - - uses: Jimver/cuda-toolkit@v0.2.22 - if: startsWith(matrix.os, 'windows') - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda_version }} - method: "network" - sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' - linux-local-args: '["--toolkit"]' - use-github-cache: false - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cuda.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - cuda_version: ${{ matrix.cuda_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} - path: output/* - retention-days: 7 - build-shared-libs-rocm: - strategy: - matrix: - os: [ubuntu-22.04] - arch: [x86_64] - rocm_version: - ["6.1.2", "6.2.4", "6.3.2"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Set up Docker multiarch - uses: docker/setup-qemu-action@v3 - - name: Clean up disk space - run: | - sudo rm -rf \ - /usr/share/dotnet \ - /opt/ghc \ - "/usr/local/share/boost" \ - "$AGENT_TOOLSDIRECTORY" \ - /opt/hostedtoolcache \ - /opt/google/chrome \ - /opt/microsoft/msedge \ - /opt/microsoft/powershell \ - /opt/pipx \ - /usr/lib/mono \ - /usr/local/julia* \ - /usr/local/lib/android \ - /usr/local/lib/node_modules \ - /usr/local/share/chromium \ - /usr/local/share/powershell \ - /usr/share/swift - - name: Build C++ - run: bash .github/scripts/build-rocm.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - rocm_version: ${{ matrix.rocm_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} - path: output/* - retention-days: 7 - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - - build-shared-libs-rocm - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - # The specific Python version is irrelevant in this context as we are only packaging non-C extension - # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is - # dictated by the packaged code itself, not the Python version used for packaging. - python-version: ["3.10"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - merge-multiple: true - pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" - path: output/ - - name: Copy correct platform shared library - shell: bash - run: | - ls -lR output/ - cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - - run: pip install build wheel - - run: python -m build . - - name: Determine and Set Platform Tag, then Tag Wheel - shell: bash - run: | - PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") - echo "PLATFORM_TAG=$PLATFORM_TAG" - wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: dist/bitsandbytes-*.whl - retention-days: 7 - - upload-pre-release-wheels: - name: Create release and upload artifacts - runs-on: ubuntu-latest - if: github.ref_name == 'main' - permissions: - contents: write - needs: - - build-wheels - steps: - - name: Download and rename artifacts - uses: actions/download-artifact@v4 - with: - path: tmp/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Inspect tmp directory after downloading artifacts - run: ls -alFR tmp/ - - - name: Move and rename wheel files with pattern replacement - run: | - mkdir -p wheels/ - - # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name - # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the - # wheel directly from the GH pre-release which gets updated continuously, e.g. - # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` - STABLE_PLACEHOLDER_VERSION="1.33.7.preview" - - # exclude macos wheels for now - find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do - wheel_filename=$(basename "$wheel") - - # Strip off the original version - rest=${wheel_filename#bitsandbytes-*-} - new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" - - echo "Renaming $wheel_filename → $new_name" - mv "$wheel" "wheels/${new_name}" - done - - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ - - - name: Delete old pre-release (if exists) - run: | - gh release delete continuous-release_main --cleanup-tag -y || true - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Generate pip install commands for release body - run: | - cat > body.md << 'ENDOFMARKDOWN' - ## Latest `main` Wheel Pre-release - - This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - - **How to install:** - Pick the correct command for your platform and run it in your terminal: - - ENDOFMARKDOWN - - for whl in wheels/*.whl; do - fname=$(basename "$whl") - url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" - echo "\`\`\`sh" >> body.md - echo "pip install $url" >> body.md - echo "\`\`\`" >> body.md - echo "" >> body.md - done - - cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** - > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. - ENDOFMARKDOWN - - # for debugging: - cat body.md - - - name: Create new pre-release and upload artifacts - uses: softprops/action-gh-release@v2.2.1 - with: - files: wheels/*.whl - prerelease: true - name: Latest `main` wheel - body_path: body.md - tag_name: continuous-release_main - make_latest: false - draft: false - target_commitish: ${{ github.sha }} - - audit-wheels: - needs: build-wheels - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - runs-on: ${{ matrix.os }} - env: - PIP_DISABLE_PIP_VERSION_CHECK: 1 - steps: - - uses: actions/checkout@v4 - - name: Download wheel - uses: actions/download-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: wheels/ - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - run: pip install auditwheel - - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY - - publish-wheels: - name: Publish wheels to PyPI - needs: [build-wheels, audit-wheels] - runs-on: ubuntu-latest - if: | - github.repository == 'bitsandbytes-foundation/bitsandbytes' - && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - environment: - name: release - url: https://pypi.org/p/bitsandbytes - permissions: - id-token: write - steps: - - name: Download distribution artifacts - uses: actions/download-artifact@v4 - with: - path: dist/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Remove macOS wheels - run: rm dist/*macos* - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - print-hash: true + - name: Delete old pre-release (if exists) + run: | + gh release delete continuous-release_main --cleanup-tag -y || true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate pip install commands for release body + run: | + cat > body.md << 'ENDOFMARKDOWN' + ## Latest `main` Wheel Pre-release + + This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. + + **How to install:** + Pick the correct command for your platform and run it in your terminal: + + ENDOFMARKDOWN + + for whl in wheels/*.whl; do + fname=$(basename "$whl") + url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" + echo "\`\`\`sh" >> body.md + echo "pip install $url" >> body.md + echo "\`\`\`" >> body.md + echo "" >> body.md + done + + cat >> body.md << 'ENDOFMARKDOWN' + > **Note:** + > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. + ENDOFMARKDOWN + + # for debugging: + cat body.md + + - name: Create new pre-release and upload artifacts + uses: softprops/action-gh-release@v2.2.1 + with: + files: wheels/*.whl + prerelease: true + name: Latest `main` wheel + body_path: body.md + tag_name: continuous-release_main + make_latest: false + draft: false + target_commitish: ${{ github.sha }} + + audit-wheels: + needs: build-wheels + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + runs-on: ${{ matrix.os }} + env: + PIP_DISABLE_PIP_VERSION_CHECK: 1 + steps: + - uses: actions/checkout@v4 + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: wheels/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install auditwheel + - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY + + publish-wheels: + name: Publish wheels to PyPI + needs: [build-wheels, audit-wheels] + runs-on: ubuntu-latest + if: | + github.repository == 'bitsandbytes-foundation/bitsandbytes' + && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + environment: + name: release + url: https://pypi.org/p/bitsandbytes + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + path: dist/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Remove macOS wheels + run: rm dist/*macos* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + print-hash: true From da9a271446295e012cd61263836ab8fea0a06af8 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 00:06:56 +0530 Subject: [PATCH 057/102] Update python-package.yml --- .github/workflows/python-package.yml | 53 +++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbaa27d56..8b0bbb374 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,10 +102,55 @@ jobs: path: output/* retention-days: 7 - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] From 08848daddb2ec6bd13f7b5a0720bd6d34988d818 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 00:12:54 +0530 Subject: [PATCH 058/102] Update python-package.yml --- .github/workflows/python-package.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 8b0bbb374..a65d0f5bb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -145,12 +145,12 @@ jobs: name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} path: output/* retention-days: 7 - - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - - build-shared-libs-rocm + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] From 978cba3825e3624bc39d594a2bd01c2444e1af69 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 01:33:00 +0530 Subject: [PATCH 059/102] Create build-rocm.sh --- .github/scripts/build-rocm.sh | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .github/scripts/build-rocm.sh diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..b508fac69 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,21 @@ +#!/bin/bash +declare build_arch +declare build_os +declare rocm_version + +set -xeuo pipefail +bnb_rocm_arch="gfx90a;gfx942;gfx1100" +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:${rocm_version}-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake --build ." +fi + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") From af6561aec6d7df66f58d4f667e1f1307aef57011 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 4 Jun 2025 00:34:30 +0530 Subject: [PATCH 060/102] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 61d03083c..bbdf457cc 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,6 @@ import dataclasses -import logging -import re +import logging +import re import subprocess from functools import lru_cache from typing import Optional @@ -78,25 +78,25 @@ def get_cuda_specs() -> Optional[CUDASpecs]: return None -def get_rocm_gpu_arch() -> str: - """Get ROCm GPU architecture.""" - logger = logging.getLogger(__name__) - try: - if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") - if torch.cuda.is_available(): - logger.warning( - """ -ROCm GPU architecture detection failed despite ROCm being available. - """, - ) - return "unknown" +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" From 405b4843fe2dffc0ab8059f82a4e3fb399ed10f0 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 4 Jun 2025 00:54:11 +0530 Subject: [PATCH 061/102] Fix trailing whitespace --- .github/workflows/python-package.yml | 96 +++---- bitsandbytes/backends/cuda/ops.py | 36 +-- bitsandbytes/cextension.py | 16 +- bitsandbytes/cuda_specs.py | 2 +- bitsandbytes/diagnostics/cuda.py | 12 +- bitsandbytes/diagnostics/main.py | 3 +- bitsandbytes/functional.py | 10 +- bitsandbytes/nn/modules.py | 4 +- conflicts.diff | 382 +++++++++++++++++++++++++++ csrc/common_hip.cuh | 2 +- csrc/kernels.hip | 26 +- csrc/ops.hip | 10 +- tests/test_cuda_setup_evaluator.py | 2 + tests/test_functional.py | 15 +- tests/test_linear4bit.py | 1 + tests/test_ops.py | 2 +- 16 files changed, 506 insertions(+), 113 deletions(-) create mode 100644 conflicts.diff diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index a65d0f5bb..3673ac608 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,49 +102,49 @@ jobs: path: output/* retention-days: 7 - build-shared-libs-rocm: - strategy: - matrix: - os: [ubuntu-22.04] - arch: [x86_64] - rocm_version: - ["6.1.2", "6.2.4", "6.3.2"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Set up Docker multiarch - uses: docker/setup-qemu-action@v3 - - name: Clean up disk space - run: | - sudo rm -rf \ - /usr/share/dotnet \ - /opt/ghc \ - "/usr/local/share/boost" \ - "$AGENT_TOOLSDIRECTORY" \ - /opt/hostedtoolcache \ - /opt/google/chrome \ - /opt/microsoft/msedge \ - /opt/microsoft/powershell \ - /opt/pipx \ - /usr/lib/mono \ - /usr/local/julia* \ - /usr/local/lib/android \ - /usr/local/lib/node_modules \ - /usr/local/share/chromium \ - /usr/local/share/powershell \ - /usr/share/swift - - name: Build C++ - run: bash .github/scripts/build-rocm.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - rocm_version: ${{ matrix.rocm_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} - path: output/* - retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 build-wheels: needs: @@ -216,10 +216,10 @@ jobs: path: tmp/ pattern: "bdist_wheel_*" merge-multiple: true - + - name: Inspect tmp directory after downloading artifacts run: ls -alFR tmp/ - + - name: Move and rename wheel files with pattern replacement run: | mkdir -p wheels/ @@ -244,7 +244,7 @@ jobs: - name: Inspect wheels directory after renaming files run: ls -alFR wheels/ - + - name: Delete old pre-release (if exists) run: | gh release delete continuous-release_main --cleanup-tag -y || true @@ -258,7 +258,7 @@ jobs: This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - **How to install:** + **How to install:** Pick the correct command for your platform and run it in your terminal: ENDOFMARKDOWN @@ -273,7 +273,7 @@ jobs: done cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** + > **Note:** > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. ENDOFMARKDOWN diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd7b7b9a2..9089d6fc2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT +from ...cextension import HIP_ENVIRONMENT, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,12 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -269,11 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], @@ -303,11 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -385,11 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 108aa0c9a..5283df93e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -81,7 +81,7 @@ def get_available_cuda_binary_versions() -> list[str]: lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - pattern = r"{}(\d+)".format(BNB_BACKEND.lower()) + pattern = rf"{BNB_BACKEND.lower()}(\d+)" match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) @@ -199,18 +199,16 @@ def _format_lib_error_message( ) compile_instructions = ( - ( - "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" - ) if not no_cuda_lib_found - else - ( + ("COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n") + if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" - ) if not HIP_ENVIRONMENT - else - ( + ) + if not HIP_ENVIRONMENT + else ( "You can COMPILE FROM SOURCE as mentioned here:\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" ) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index bbdf457cc..32563a159 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,8 +1,8 @@ import dataclasses +from functools import lru_cache import logging import re import subprocess -from functools import lru_cache from typing import Optional import torch diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index b9de27fd7..b9db101ab 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -33,11 +33,13 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( - "libamdhip64.so*", -) if HIP_ENVIRONMENT else ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows + ("libamdhip64.so*",) + if HIP_ENVIRONMENT + else ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows + ) ) logger = logging.getLogger(__name__) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 8e2bc2a7b..bf31d7978 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -43,7 +43,8 @@ def main(): print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + if not HIP_ENVIRONMENT: + print(f"- {BNB_BACKEND} driver not installed") print(f"- {BNB_BACKEND} not installed") print(f"- You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 03f6c323d..9b7ce2da9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib, HIP_ENVIRONMENT +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -1007,10 +1007,10 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1114,10 +1114,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2383f2c10..a2facac28 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -222,10 +222,10 @@ def __new__( ) -> "Params4bit": if data is None: data = torch.empty(0) - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics diff --git a/conflicts.diff b/conflicts.diff new file mode 100644 index 000000000..cab8c6ea7 --- /dev/null +++ b/conflicts.diff @@ -0,0 +1,382 @@ +diff --cc bitsandbytes/cextension.py +index 108aa0c,b112df2..0000000 +--- a/bitsandbytes/cextension.py ++++ b/bitsandbytes/cextension.py +@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + + if torch.version.hip: + + raise RuntimeError( + + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + + ) + logger.warning( + f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" +- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" ++ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" + "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" +- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" +- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi + return BNBNativeLibrary(dll) + + + +ROCM_GPU_ARCH = get_rocm_gpu_arch() + + + try: +++<<<<<<< HEAD + + if torch.version.hip: + + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + + else: + + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" + + +++======= ++ # to support Intel CPU/GPU (XPU) backend ++ import intel_extension_for_pytorch as ipex ++ ++ ipex_cpu = ipex if ipex._C._has_cpu() else None ++ ipex_xpu = ipex if ipex._C._has_xpu() else None ++ except BaseException: ++ ipex_cpu = None ++ ipex_xpu = None ++ ++ ++ try: +++>>>>>>> upstream/main + lib = get_native_library() + except Exception as e: + error_msg = str(e) +diff --cc bitsandbytes/diagnostics/cuda.py +index b9de27f,e763ef2..0000000 +--- a/bitsandbytes/diagnostics/cuda.py ++++ b/bitsandbytes/diagnostics/cuda.py +@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat + + import torch + +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path + +from bitsandbytes.consts import NONPYTORCH_DOC_URL +++======= ++ from bitsandbytes.cextension import get_cuda_bnb_library_path +++>>>>>>> upstream/main + from bitsandbytes.cuda_specs import CUDASpecs + from bitsandbytes.diagnostics.utils import print_dedented + +@@@ -146,42 -127,8 +134,38 @@@ def _print_cuda_diagnostics(cuda_specs + """, + ) + +- # TODO: +- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) +- # (2) Multiple CUDA versions installed +- + + -def print_cuda_runtime_diagnostics() -> None: + +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + + + binary_path = get_cuda_bnb_library_path(cuda_specs) + + if not binary_path.exists(): + + print_dedented( + + f""" + + Library not found: {binary_path}. + + Maybe you need to compile it from source? If you compiled from source, check that ROCm version + + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + + and rebuild bitsandbytes. + + """, + + ) + + + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + + if (hip_major, hip_minor) < (6, 1): + + print_dedented( + + """ + + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + + """, + + ) + + + + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + + if HIP_ENVIRONMENT: + + _print_hip_diagnostics(cuda_specs) + + else: + + _print_cuda_diagnostics(cuda_specs) + + + + + +def _print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") +diff --cc bitsandbytes/diagnostics/main.py +index 8e2bc2a,aa4cb30..0000000 +--- a/bitsandbytes/diagnostics/main.py ++++ b/bitsandbytes/diagnostics/main.py +@@@ -3,12 -5,11 +5,20 @@@ import tracebac + + import torch + +++<<<<<<< HEAD + +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT + +from bitsandbytes.consts import PACKAGE_GITHUB_URL + +from bitsandbytes.cuda_specs import get_cuda_specs + +from bitsandbytes.diagnostics.cuda import ( + + print_diagnostics, + + print_runtime_diagnostics, +++======= ++ from bitsandbytes import __version__ as bnb_version ++ from bitsandbytes.consts import PACKAGE_GITHUB_URL ++ from bitsandbytes.cuda_specs import get_cuda_specs ++ from bitsandbytes.diagnostics.cuda import ( ++ print_cuda_diagnostics, +++>>>>>>> upstream/main + ) + from bitsandbytes.diagnostics.utils import print_dedented, print_header + +@@@ -28,52 -41,77 +50,122 @@@ def sanity_check() + assert p1 != p2 + + ++ def get_package_version(name: str) -> str: ++ try: ++ version = importlib.metadata.version(name) ++ except importlib.metadata.PackageNotFoundError: ++ version = "not found" ++ return version ++ ++ ++ def show_environment(): ++ """Simple utility to print out environment information.""" ++ ++ print(f"Platform: {platform.platform()}") ++ if platform.system() == "Linux": ++ print(f" libc: {'-'.join(platform.libc_ver())}") ++ ++ print(f"Python: {platform.python_version()}") ++ ++ print(f"PyTorch: {torch.__version__}") ++ print(f" CUDA: {torch.version.cuda or 'N/A'}") ++ print(f" HIP: {torch.version.hip or 'N/A'}") ++ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") ++ ++ print("Related packages:") ++ for pkg in _RELATED_PACKAGES: ++ version = get_package_version(pkg) ++ print(f" {pkg}: {version}") ++ ++ + def main(): +- print_header("") +- print_header("BUG REPORT INFORMATION") ++ print_header(f"bitsandbytes v{bnb_version}") ++ show_environment() + print_header("") + +- print_header("OTHER") + cuda_specs = get_cuda_specs() +++<<<<<<< HEAD + + if HIP_ENVIRONMENT: + + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," + + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + + print(f"{BNB_BACKEND} specs:{rocm_specs}") + + else: + + print(f"{BNB_BACKEND} specs:{cuda_specs}") + + if not torch.cuda.is_available(): + + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + + print(f"- {BNB_BACKEND} not installed") + + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") + + if cuda_specs: + + print_diagnostics(cuda_specs) + + print_runtime_diagnostics() + + print_header("") + + print_header("DEBUG INFO END") + + print_header("") + + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + + try: + + sanity_check() + + print("SUCCESS!") + + print("Installation was successful!") + + return + + except RuntimeError as e: + + if "not available in CPU-only" in str(e): + + print( + + f"WARNING: {__package__} is currently running as CPU-only!\n" + + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + + f"If you think that this is so erroneously,\nplease report an issue!", + + ) + + else: + + raise e + + except Exception: + + traceback.print_exc() + + print_dedented( + + f""" + + Above we output some debug information. + + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + + """, + + ) + + sys.exit(1) +++======= ++ ++ if cuda_specs: ++ print_cuda_diagnostics(cuda_specs) ++ ++ # TODO: There's a lot of noise in this; needs improvement. ++ # print_cuda_runtime_diagnostics() ++ ++ if not torch.cuda.is_available(): ++ print("PyTorch says CUDA is not available. Possible reasons:") ++ print("1. CUDA driver not installed") ++ print("2. Using a CPU-only PyTorch build") ++ print("3. No GPU detected") ++ ++ else: ++ print("Checking that the library is importable and CUDA is callable...") ++ ++ try: ++ sanity_check() ++ print("SUCCESS!") ++ return ++ except RuntimeError as e: ++ if "not available in CPU-only" in str(e): ++ print( ++ f"WARNING: {__package__} is currently running as CPU-only!\n" ++ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" ++ f"If you think that this is so erroneously,\nplease report an issue!", ++ ) ++ else: ++ raise e ++ except Exception: ++ traceback.print_exc() ++ ++ print_dedented( ++ f""" ++ Above we output some debug information. ++ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ++ WARNING: Please be sure to sanitize sensitive info from the output before posting it. ++ """, ++ ) ++ sys.exit(1) +++>>>>>>> upstream/main +diff --cc bitsandbytes/functional.py +index 03f6c32,ffb6668..0000000 +mode 100644,100755..100755 +--- a/bitsandbytes/functional.py ++++ b/bitsandbytes/functional.py +@@@ -13,9 -13,9 +13,13 @@@ import torc + from torch import Tensor + from typing_extensions import deprecated + +- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict ++ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict + +++<<<<<<< HEAD + +from .cextension import lib, HIP_ENVIRONMENT +++======= ++ from .cextension import ipex_cpu, ipex_xpu, lib +++>>>>>>> upstream/main + + name2qmap = {} + +diff --cc bitsandbytes/nn/modules.py +index 2383f2c,ccd842c..0000000 +--- a/bitsandbytes/nn/modules.py ++++ b/bitsandbytes/nn/modules.py +@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype + import torch.nn.functional as F + + import bitsandbytes as bnb +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT + +from bitsandbytes.functional import QuantState +++======= ++ from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu +++>>>>>>> upstream/main + from bitsandbytes.optim import GlobalOptimManager + from bitsandbytes.utils import ( + INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, +diff --cc tests/test_linear4bit.py +index 1b7a772,b5db2eb..0000000 +--- a/tests/test_linear4bit.py ++++ b/tests/test_linear4bit.py +@@@ -7,8 -8,14 +8,19 @@@ import pytes + import torch + + import bitsandbytes as bnb +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT + +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer +++======= ++ from tests.helpers import ( ++ TRUE_FALSE, ++ describe_dtype, ++ get_available_devices, ++ id_formatter, ++ torch_load_from_buffer, ++ torch_save_to_buffer, ++ ) +++>>>>>>> upstream/main + + storage = { + "uint8": torch.uint8, +@@@ -183,16 -185,10 +189,10 @@@ def test_linear_serialization(device, q + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_copy_param(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- tensor = torch.linspace(1, blocksize, blocksize) ++ tensor = torch.randn(300, 400) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, +@@@ -208,16 -204,10 +208,10 @@@ + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- tensor = torch.linspace(1, blocksize, blocksize) ++ tensor = torch.randn(300, 400) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, +@@@ -240,16 -230,10 +234,10 @@@ + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) ++ original_tensor = torch.randn(300, 400) + original_param = bnb.nn.Params4bit( + data=original_tensor, + quant_type=quant_type, diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index e7fc4eb81..105179535 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,6 +1,6 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#define BNB_WARP_SIZE warpSize // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs #define BNB_MAX_THREADS_PER_SM 2048 diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 368788f39..56e1d54db 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -532,7 +532,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float absmax[i / BLOCK_SIZE] = local_abs_max; } __syncthreads(); - + local_abs_max = smem_absmax_value[0]; if(STOCHASTIC) @@ -610,7 +610,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } - + // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. @@ -811,7 +811,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p, LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - + // Load additional state1 data for AdEMAMix // TODO: Make constexpr after updating min compiler if (OPTIMIZER == ADEMAMIX) { @@ -1607,7 +1607,7 @@ kOptimizerStatic8bit2StateBlockwise( unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; unsigned char c3s[N_PER_TH]; - + T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; typedef hipcub::BlockLoad LoadT; @@ -1712,7 +1712,7 @@ kOptimizerStatic8bit2StateBlockwise( new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); - + if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); } @@ -1776,7 +1776,7 @@ kOptimizerStatic8bit2StateBlockwise( } else { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); } - + if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -2148,27 +2148,27 @@ __global__ void kdequant_mm_int32_fp16( int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; - + float local_rowStats[ITEMS_PER_THREAD]; float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; typedef hipcub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; - + int row_idx, col_idx; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) { row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; - + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); } - + // Each block loads THREADS * ITEMS_PER_THREAD values from A int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD @@ -2188,7 +2188,7 @@ __global__ void kdequant_mm_int32_fp16( if (outIdx < n_out) { out[outIdx] = local_output[j]; } - } + } } #define DENORM 1.0f/127.0f diff --git a/csrc/ops.hip b/csrc/ops.hip index 4d077d19a..eef616d48 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -199,10 +199,10 @@ template void optimizerStatic8bit(T* p, T* g, } } -#define BLOCKSIZE_2STATE 256 -#define NUM_2STATE 1 -#define BLOCKSIZE_1STATE 256 -#define NUM_1STATE 1 +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 template void optimizerStatic8bitBlockwise( T* p, @@ -443,7 +443,7 @@ static std::string hipError_to_string(const hipError_t ret) } template int igemmlt( - hipblasLtHandle_t ltHandle, + hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 1b2ea85db..3d8b688ee 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -12,11 +12,13 @@ def cuda120_spec() -> CUDASpecs: cuda_version_tuple=(12, 0), ) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") diff --git a/tests/test_functional.py b/tests/test_functional.py index 5f5ee488c..a2964c733 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,8 +8,8 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from bitsandbytes import functional as F +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -92,7 +92,10 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] ) + @pytest.mark.parametrize( + "blocksize", + [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -796,6 +799,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @@ -1106,7 +1110,10 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize( + "blocksize", + [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1205,7 +1212,7 @@ def test_bench_4bit_dequant(self, quant_type): # torch.matmul(b, a.t()) # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - + @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1b7a7722c..60c163477 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -17,6 +17,7 @@ "float32": torch.float32, } + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) diff --git a/tests/test_ops.py b/tests/test_ops.py index a99d080b3..a433a0c4b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,8 +4,8 @@ import torch import bitsandbytes -from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter class TestLLMInt8Ops: From 93768d07b1b753790a784f1472e5b6b1f9fa5c73 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 4 Jun 2025 01:24:09 +0530 Subject: [PATCH 062/102] Remove conflicts.diff --- conflicts.diff | 382 ------------------------------------------------- 1 file changed, 382 deletions(-) delete mode 100644 conflicts.diff diff --git a/conflicts.diff b/conflicts.diff deleted file mode 100644 index cab8c6ea7..000000000 --- a/conflicts.diff +++ /dev/null @@ -1,382 +0,0 @@ -diff --cc bitsandbytes/cextension.py -index 108aa0c,b112df2..0000000 ---- a/bitsandbytes/cextension.py -+++ b/bitsandbytes/cextension.py -@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec - override_value = os.environ.get("BNB_CUDA_VERSION") - if override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) - + if torch.version.hip: - + raise RuntimeError( - + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" - + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" - + ) - logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" -- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" -+ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" -- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" -- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi - return BNBNativeLibrary(dll) - - - +ROCM_GPU_ARCH = get_rocm_gpu_arch() - + - try: -++<<<<<<< HEAD - + if torch.version.hip: - + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - + else: - + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - + -++======= -+ # to support Intel CPU/GPU (XPU) backend -+ import intel_extension_for_pytorch as ipex -+ -+ ipex_cpu = ipex if ipex._C._has_cpu() else None -+ ipex_xpu = ipex if ipex._C._has_xpu() else None -+ except BaseException: -+ ipex_cpu = None -+ ipex_xpu = None -+ -+ -+ try: -++>>>>>>> upstream/main - lib = get_native_library() - except Exception as e: - error_msg = str(e) -diff --cc bitsandbytes/diagnostics/cuda.py -index b9de27f,e763ef2..0000000 ---- a/bitsandbytes/diagnostics/cuda.py -+++ b/bitsandbytes/diagnostics/cuda.py -@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path - +from bitsandbytes.consts import NONPYTORCH_DOC_URL -++======= -+ from bitsandbytes.cextension import get_cuda_bnb_library_path -++>>>>>>> upstream/main - from bitsandbytes.cuda_specs import CUDASpecs - from bitsandbytes.diagnostics.utils import print_dedented - -@@@ -146,42 -127,8 +134,38 @@@ def _print_cuda_diagnostics(cuda_specs - """, - ) - -- # TODO: -- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) -- # (2) Multiple CUDA versions installed -- - - -def print_cuda_runtime_diagnostics() -> None: - +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") - + - + binary_path = get_cuda_bnb_library_path(cuda_specs) - + if not binary_path.exists(): - + print_dedented( - + f""" - + Library not found: {binary_path}. - + Maybe you need to compile it from source? If you compiled from source, check that ROCm version - + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version - + and rebuild bitsandbytes. - + """, - + ) - + - + hip_major, hip_minor = cuda_specs.cuda_version_tuple - + if (hip_major, hip_minor) < (6, 1): - + print_dedented( - + """ - + WARNING: bitsandbytes is fully supported only from ROCm 6.1. - + """, - + ) - + - + - +def print_diagnostics(cuda_specs: CUDASpecs) -> None: - + if HIP_ENVIRONMENT: - + _print_hip_diagnostics(cuda_specs) - + else: - + _print_cuda_diagnostics(cuda_specs) - + - + - +def _print_cuda_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") -diff --cc bitsandbytes/diagnostics/main.py -index 8e2bc2a,aa4cb30..0000000 ---- a/bitsandbytes/diagnostics/main.py -+++ b/bitsandbytes/diagnostics/main.py -@@@ -3,12 -5,11 +5,20 @@@ import tracebac - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT - +from bitsandbytes.consts import PACKAGE_GITHUB_URL - +from bitsandbytes.cuda_specs import get_cuda_specs - +from bitsandbytes.diagnostics.cuda import ( - + print_diagnostics, - + print_runtime_diagnostics, -++======= -+ from bitsandbytes import __version__ as bnb_version -+ from bitsandbytes.consts import PACKAGE_GITHUB_URL -+ from bitsandbytes.cuda_specs import get_cuda_specs -+ from bitsandbytes.diagnostics.cuda import ( -+ print_cuda_diagnostics, -++>>>>>>> upstream/main - ) - from bitsandbytes.diagnostics.utils import print_dedented, print_header - -@@@ -28,52 -41,77 +50,122 @@@ def sanity_check() - assert p1 != p2 - - -+ def get_package_version(name: str) -> str: -+ try: -+ version = importlib.metadata.version(name) -+ except importlib.metadata.PackageNotFoundError: -+ version = "not found" -+ return version -+ -+ -+ def show_environment(): -+ """Simple utility to print out environment information.""" -+ -+ print(f"Platform: {platform.platform()}") -+ if platform.system() == "Linux": -+ print(f" libc: {'-'.join(platform.libc_ver())}") -+ -+ print(f"Python: {platform.python_version()}") -+ -+ print(f"PyTorch: {torch.__version__}") -+ print(f" CUDA: {torch.version.cuda or 'N/A'}") -+ print(f" HIP: {torch.version.hip or 'N/A'}") -+ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") -+ -+ print("Related packages:") -+ for pkg in _RELATED_PACKAGES: -+ version = get_package_version(pkg) -+ print(f" {pkg}: {version}") -+ -+ - def main(): -- print_header("") -- print_header("BUG REPORT INFORMATION") -+ print_header(f"bitsandbytes v{bnb_version}") -+ show_environment() - print_header("") - -- print_header("OTHER") - cuda_specs = get_cuda_specs() -++<<<<<<< HEAD - + if HIP_ENVIRONMENT: - + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" - + print(f"{BNB_BACKEND} specs:{rocm_specs}") - + else: - + print(f"{BNB_BACKEND} specs:{cuda_specs}") - + if not torch.cuda.is_available(): - + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") - + print(f"- {BNB_BACKEND} not installed") - + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") - + if cuda_specs: - + print_diagnostics(cuda_specs) - + print_runtime_diagnostics() - + print_header("") - + print_header("DEBUG INFO END") - + print_header("") - + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - + try: - + sanity_check() - + print("SUCCESS!") - + print("Installation was successful!") - + return - + except RuntimeError as e: - + if "not available in CPU-only" in str(e): - + print( - + f"WARNING: {__package__} is currently running as CPU-only!\n" - + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - + f"If you think that this is so erroneously,\nplease report an issue!", - + ) - + else: - + raise e - + except Exception: - + traceback.print_exc() - + print_dedented( - + f""" - + Above we output some debug information. - + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - + WARNING: Please be sure to sanitize sensitive info from the output before posting it. - + """, - + ) - + sys.exit(1) -++======= -+ -+ if cuda_specs: -+ print_cuda_diagnostics(cuda_specs) -+ -+ # TODO: There's a lot of noise in this; needs improvement. -+ # print_cuda_runtime_diagnostics() -+ -+ if not torch.cuda.is_available(): -+ print("PyTorch says CUDA is not available. Possible reasons:") -+ print("1. CUDA driver not installed") -+ print("2. Using a CPU-only PyTorch build") -+ print("3. No GPU detected") -+ -+ else: -+ print("Checking that the library is importable and CUDA is callable...") -+ -+ try: -+ sanity_check() -+ print("SUCCESS!") -+ return -+ except RuntimeError as e: -+ if "not available in CPU-only" in str(e): -+ print( -+ f"WARNING: {__package__} is currently running as CPU-only!\n" -+ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" -+ f"If you think that this is so erroneously,\nplease report an issue!", -+ ) -+ else: -+ raise e -+ except Exception: -+ traceback.print_exc() -+ -+ print_dedented( -+ f""" -+ Above we output some debug information. -+ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose -+ WARNING: Please be sure to sanitize sensitive info from the output before posting it. -+ """, -+ ) -+ sys.exit(1) -++>>>>>>> upstream/main -diff --cc bitsandbytes/functional.py -index 03f6c32,ffb6668..0000000 -mode 100644,100755..100755 ---- a/bitsandbytes/functional.py -+++ b/bitsandbytes/functional.py -@@@ -13,9 -13,9 +13,13 @@@ import torc - from torch import Tensor - from typing_extensions import deprecated - -- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -+ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict - -++<<<<<<< HEAD - +from .cextension import lib, HIP_ENVIRONMENT -++======= -+ from .cextension import ipex_cpu, ipex_xpu, lib -++>>>>>>> upstream/main - - name2qmap = {} - -diff --cc bitsandbytes/nn/modules.py -index 2383f2c,ccd842c..0000000 ---- a/bitsandbytes/nn/modules.py -+++ b/bitsandbytes/nn/modules.py -@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype - import torch.nn.functional as F - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from bitsandbytes.functional import QuantState -++======= -+ from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu -++>>>>>>> upstream/main - from bitsandbytes.optim import GlobalOptimManager - from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, -diff --cc tests/test_linear4bit.py -index 1b7a772,b5db2eb..0000000 ---- a/tests/test_linear4bit.py -+++ b/tests/test_linear4bit.py -@@@ -7,8 -8,14 +8,19 @@@ import pytes - import torch - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer -++======= -+ from tests.helpers import ( -+ TRUE_FALSE, -+ describe_dtype, -+ get_available_devices, -+ id_formatter, -+ torch_load_from_buffer, -+ torch_save_to_buffer, -+ ) -++>>>>>>> upstream/main - - storage = { - "uint8": torch.uint8, -@@@ -183,16 -185,10 +189,10 @@@ def test_linear_serialization(device, q - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_copy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -208,16 -204,10 +208,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -240,16 -230,10 +234,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) -+ original_tensor = torch.randn(300, 400) - original_param = bnb.nn.Params4bit( - data=original_tensor, - quant_type=quant_type, From e119ff73efa8aa4d48c651e2d762e5107631f22d Mon Sep 17 00:00:00 2001 From: amcamd Date: Thu, 5 Jun 2025 17:13:30 -0400 Subject: [PATCH 063/102] update for hipblasVersionMajor >=3 --- csrc/ops.hip | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/csrc/ops.hip b/csrc/ops.hip index eef616d48..a9c3e0202 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -269,6 +269,15 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in const void * beta = &fbeta; hipblasStatus_t status; +#if hipblasVersionMajor >= 3 + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, + C, HIP_R_32I, ldc, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -276,6 +285,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, C, HIPBLAS_R_32I, ldc, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) { @@ -299,6 +309,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); +#if hipblasVersionMajor >= 3 + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, + C, HIP_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmStridedBatchedEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -306,6 +325,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) { From 8dc297d32adf90a079decd0a8649736dc5258089 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:40:46 +0530 Subject: [PATCH 064/102] Update test_functional.py --- tests/test_functional.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index a2964c733..95f75d99f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -98,6 +98,9 @@ class Test8BitBlockwiseQuantizeFunctional: ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + iters = 100 if device == "cpu": @@ -150,6 +153,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, assert A2.dtype == dtype @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="CPU tests skipped when HIP_ENVIRONMENT is set") @pytest.mark.parametrize("hidden", [128]) @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): @@ -176,6 +180,9 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and bits != 8: pytest.skip("CPU implementation only supports 8 bits") @@ -232,6 +239,9 @@ def test_few_bit_quant(self, device, bits, method): @pytest.mark.parametrize("device", get_available_devices()) def test_fp8_quant(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + # TODO if device == "cpu": pytest.skip("CPU implementation segfaults") @@ -570,6 +580,9 @@ class TestLLMInt8Functional: @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + for i in range(k): if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device) @@ -588,6 +601,9 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half() @@ -611,6 +627,9 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + inner = 128 bias = None if has_bias: @@ -734,6 +753,9 @@ def test_int8_double_quant(self, dim1, dim2): ), ) def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and inner > 2048: pytest.skip("Slow on CPU") @@ -767,6 +789,9 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + threshold = 2.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -787,6 +812,9 @@ def test_coo_double_quant(self, device, dim1, dim2): @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -1115,6 +1143,9 @@ class TestQuantize4BitFunctional: [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1150,6 +1181,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1228,6 +1262,9 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1384,6 +1421,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From f7d8bf340bb9d36c3412cbbedc564c2edecc8308 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:45:28 +0530 Subject: [PATCH 065/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 60c163477..546ed2681 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -25,6 +25,9 @@ @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if quant_type == "fp4": pytest.xfail("FP4 is not supported for CPU") @@ -187,6 +190,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -212,6 +218,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -244,6 +253,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") From fd0a4d0fc4dc610fcf96a0469b41a68299d6daa5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:52:40 +0530 Subject: [PATCH 066/102] Update test_ops.py --- tests/test_ops.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_ops.py b/tests/test_ops.py index a433a0c4b..3879aa479 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -11,6 +11,9 @@ class TestLLMInt8Ops: @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) @@ -23,6 +26,9 @@ def test_int8_linear_matmul(self, device): @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul_out(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -38,6 +44,9 @@ def test_int8_linear_matmul_out(self, device): @pytest.mark.parametrize("threshold", [0.0, 6.0]) @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -64,6 +73,9 @@ def test_int8_vectorwise_quant(self, threshold, device): @pytest.mark.parametrize("device", get_available_devices()) def test_int8_mm_dequant(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) col_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -79,6 +91,9 @@ def test_int8_mm_dequant(self, device): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) row_stats = torch.randn(10, dtype=torch.float32, device=device) @@ -98,6 +113,9 @@ class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -122,6 +140,9 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -148,6 +169,9 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -172,6 +196,9 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -209,6 +236,9 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": pytest.xfail("CPU implementation is not available") From 75487d38f59e6b3c6e05182ecc42330275f488f6 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 7 Jun 2025 00:58:16 +0530 Subject: [PATCH 067/102] Update main.py --- bitsandbytes/diagnostics/main.py | 33 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 7cd04e209..ed7999f14 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -77,29 +77,25 @@ def main(): print_header("") cuda_specs = get_cuda_specs() - if HIP_ENVIRONMENT: - rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" - print(f"{BNB_BACKEND} specs:{rocm_specs}") - else: - print(f"{BNB_BACKEND} specs:{cuda_specs}") - if not torch.cuda.is_available(): - print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - if not HIP_ENVIRONMENT: - print(f"- {BNB_BACKEND} driver not installed") - print(f"- {BNB_BACKEND} not installed") - print(f"- You have multiple conflicting {BNB_BACKEND} libraries") + if cuda_specs: print_diagnostics(cuda_specs) - print_runtime_diagnostics() - print_header("") - print_header("DEBUG INFO END") - print_header("") - print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + + # TODO: There's a lot of noise in this; needs improvement. + # print_cuda_runtime_diagnostics() + + if not torch.cuda.is_available(): + print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") + print(f"1. {BNB_BACKEND} driver not installed") + print(f"2. Using a CPU-only PyTorch build") + print(f"3. No GPU detected") + + else: + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + try: sanity_check() print("SUCCESS!") - print("Installation was successful!") return except RuntimeError as e: if "not available in CPU-only" in str(e): @@ -112,6 +108,7 @@ def main(): raise e except Exception: traceback.print_exc() + print_dedented( f""" Above we output some debug information. From 3551457f987e834999b39f6df01868587e3233e3 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:32:43 +0530 Subject: [PATCH 068/102] Update test_functional.py --- tests/test_functional.py | 105 +++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 53 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 95f75d99f..719f21137 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -89,7 +89,10 @@ def reset(self): class Test8BitBlockwiseQuantizeFunctional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( @@ -98,9 +101,6 @@ class Test8BitBlockwiseQuantizeFunctional: ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - iters = 100 if device == "cpu": @@ -153,7 +153,6 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, assert A2.dtype == dtype @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="CPU tests skipped when HIP_ENVIRONMENT is set") @pytest.mark.parametrize("hidden", [128]) @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): @@ -176,13 +175,13 @@ def test_blockwise_cpu_large(self, hidden, blocksize): # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and bits != 8: pytest.skip("CPU implementation only supports 8 bits") @@ -237,11 +236,11 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_fp8_quant(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - # TODO if device == "cpu": pytest.skip("CPU implementation segfaults") @@ -572,7 +571,10 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @@ -580,9 +582,6 @@ class TestLLMInt8Functional: @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - for i in range(k): if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device) @@ -594,16 +593,16 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half() @@ -621,15 +620,15 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - inner = 128 bias = None if has_bias: @@ -740,7 +739,10 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -753,9 +755,6 @@ def test_int8_double_quant(self, dim1, dim2): ), ) def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and inner > 2048: pytest.skip("Slow on CPU") @@ -785,13 +784,13 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - threshold = 2.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -808,13 +807,13 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -1135,7 +1134,10 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( @@ -1143,9 +1145,6 @@ class TestQuantize4BitFunctional: [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1177,13 +1176,13 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1250,7 +1249,10 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1262,9 +1264,6 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1412,7 +1411,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @@ -1421,9 +1423,6 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 90437b94837529b7519e59c64a5d5774090fba80 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:44:33 +0530 Subject: [PATCH 069/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 546ed2681..fe3f4b13c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,7 +18,10 @@ } -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -185,7 +188,10 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,7 +219,10 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -248,7 +257,10 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) From a0bdc94db673238c0b3e12ff9ac03117f5f966f2 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:51:33 +0530 Subject: [PATCH 070/102] Update test_ops.py --- tests/test_ops.py | 80 +++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 3879aa479..e3be5fd50 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,11 +9,11 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) @@ -24,11 +24,11 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul_out(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,11 +42,11 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_vectorwise_quant(self, threshold, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,11 +71,11 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_mm_dequant(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) col_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -87,13 +87,13 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) row_stats = torch.randn(10, dtype=torch.float32, device=device) @@ -109,13 +109,13 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -136,13 +136,13 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -163,15 +163,15 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -190,15 +190,15 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -230,15 +230,15 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": pytest.xfail("CPU implementation is not available") From 8a27346f8fd6ecf8eea4127ea13899618d9a921c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:56:08 +0530 Subject: [PATCH 071/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index fe3f4b13c..760e4b8c9 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -21,16 +21,13 @@ @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if quant_type == "fp4": pytest.xfail("FP4 is not supported for CPU") @@ -191,14 +188,11 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -222,14 +216,11 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -260,14 +251,11 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") From c945dbb5c8b14bf54631d39f51b6c1d841981043 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 03:05:38 +0530 Subject: [PATCH 072/102] Lint --- tests/test_functional.py | 70 ++++++++++++++++++------------------ tests/test_linear4bit.py | 26 +++++++------- tests/test_ops.py | 78 ++++++++++++++++++++-------------------- 3 files changed, 87 insertions(+), 87 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 719f21137..571eea55f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -236,9 +236,9 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) def test_fp8_quant(self, device): # TODO @@ -571,9 +571,9 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @@ -593,9 +593,9 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @@ -620,9 +620,9 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @@ -739,10 +739,10 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -784,10 +784,10 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): @@ -807,9 +807,9 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) @@ -1134,9 +1134,9 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -1176,9 +1176,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) @@ -1249,9 +1249,9 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1411,9 +1411,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 760e4b8c9..ddc609616 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,9 +18,9 @@ } -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @@ -185,10 +185,10 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,9 +213,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @@ -248,9 +248,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) diff --git a/tests/test_ops.py b/tests/test_ops.py index e3be5fd50..9d406b793 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,10 +9,10 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -24,10 +24,10 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul_out(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,10 +42,10 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_vectorwise_quant(self, threshold, device): A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,10 +71,10 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_mm_dequant(self, device): A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -87,9 +87,9 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) @@ -109,10 +109,10 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): @@ -136,10 +136,10 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): @@ -163,10 +163,10 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -190,10 +190,10 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -230,10 +230,10 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From 58e989ef989852e98ac11540d1e0b144b4f68783 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 11:43:45 +0530 Subject: [PATCH 073/102] Lint --- bitsandbytes/diagnostics/main.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index ed7999f14..9a0447433 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -6,12 +6,11 @@ import torch from bitsandbytes import __version__ as bnb_version -from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT +from bitsandbytes.cextension import BNB_BACKEND from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( print_diagnostics, - print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -77,18 +76,18 @@ def main(): print_header("") cuda_specs = get_cuda_specs() - + if cuda_specs: print_diagnostics(cuda_specs) # TODO: There's a lot of noise in this; needs improvement. # print_cuda_runtime_diagnostics() - + if not torch.cuda.is_available(): print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") print(f"1. {BNB_BACKEND} driver not installed") - print(f"2. Using a CPU-only PyTorch build") - print(f"3. No GPU detected") + print("2. Using a CPU-only PyTorch build") + print("3. No GPU detected") else: print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") @@ -108,7 +107,7 @@ def main(): raise e except Exception: traceback.print_exc() - + print_dedented( f""" Above we output some debug information. From 2cce3366b363de9499220a00a62e34e88183ced4 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:03:19 +0530 Subject: [PATCH 074/102] Update helpers.py --- tests/helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index fbc4af071..671ea39eb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,8 @@ import torch +from bitsandbytes.cextension import HIP_ENVIRONMENT + test_dims_rng = random.Random(42) @@ -21,7 +23,7 @@ def get_available_devices(): # If the environment variable is set, use it directly. return [os.environ["BNB_TEST_DEVICE"]] - devices = ["cpu"] + devices = [] if HIP_ENVIRONMENT else ["cpu"] if hasattr(torch, "accelerator"): # PyTorch 2.6+ - determine accelerator using agnostic API. From 5eb0316802d87c25e0c850a13c7cec77e9648583 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:15:09 +0530 Subject: [PATCH 075/102] Update test_functional.py --- tests/test_functional.py | 65 ++++++++-------------------------------- 1 file changed, 13 insertions(+), 52 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 571eea55f..a2964c733 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -89,10 +89,7 @@ def reset(self): class Test8BitBlockwiseQuantizeFunctional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( @@ -175,10 +172,7 @@ def test_blockwise_cpu_large(self, hidden, blocksize): # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): @@ -236,10 +230,7 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_fp8_quant(self, device): # TODO if device == "cpu": @@ -571,10 +562,7 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @@ -593,10 +581,7 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @@ -620,10 +605,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @@ -739,10 +721,7 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -784,10 +763,7 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): @@ -807,10 +783,7 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): @@ -1134,10 +1107,7 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( @@ -1176,10 +1146,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): @@ -1249,10 +1216,7 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1411,10 +1375,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) From dcdf2c54ffe023295a1b6f60edab18d60b073552 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:15:41 +0530 Subject: [PATCH 076/102] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index ddc609616..60c163477 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,10 +18,7 @@ } -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -185,10 +182,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,10 +207,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -248,10 +239,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) From 6bba74052813946d28b238755d43756ff0e6c4f5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:16:22 +0530 Subject: [PATCH 077/102] Update test_ops.py --- tests/test_ops.py | 50 ++++++++++------------------------------------- 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 9d406b793..a433a0c4b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,10 +9,7 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -24,10 +21,7 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul_out(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,10 +36,7 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,10 +62,7 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_mm_dequant(self, device): A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -87,10 +75,7 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): @@ -109,10 +94,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): @@ -136,10 +118,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): @@ -163,10 +142,7 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -190,10 +166,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -230,10 +203,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From bdd67545ed1c66d1fc7cf5b84118b6bba107755e Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 15:18:19 +0530 Subject: [PATCH 078/102] Lint --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 671ea39eb..54eec95dc 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,7 @@ import torch -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT test_dims_rng = random.Random(42) From 3db3196e18ac46496dbc60569b0efd1d603b1f53 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 18 Jun 2025 07:08:21 +0530 Subject: [PATCH 079/102] Update pythonInterface.cpp --- csrc/pythonInterface.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 66e96b07f..a8d47b8de 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -315,12 +315,6 @@ void spmm_coo_very_sparse_naive_int8( extern "C" { #if BUILD_CUDA || BUILD_HIP -void cestimate_quantiles_fp32(float* A, float* code, float offset, int n) { - estimateQuantiles_fp32(A, code, offset, n); -} - -void cestimate_quantiles_fp16(half* A, float* code, float offset, int n) { estimateQuantiles_fp16(A, code, offset, n); } - void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { From 75a654e3e1eacb6ba78b98bc153925377e530bd8 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 18 Jun 2025 07:11:20 +0530 Subject: [PATCH 080/102] lint fix --- csrc/common_hip.cuh | 2 +- csrc/kernels_hip.cuh | 236 +++++++++++++++++---------------- csrc/ops_hip.cuh | 302 +++++++++++++++++++++++-------------------- 3 files changed, 288 insertions(+), 252 deletions(-) diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index 105179535..1d9d9afe0 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,6 +1,6 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#define BNB_WARP_SIZE warpSize // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs #define BNB_MAX_THREADS_PER_SM 2048 diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 2895012f8..811299d05 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -11,122 +11,136 @@ #ifndef kernels #define kernels - -template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); - -__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); -__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); - -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); - -template -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, - const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - const float weight_decay, - const float gnorm_scale, const int n); - - -template +template __global__ void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n); - + kEstimateQuantiles(T* __restrict__ const A, float* code, const float offset, const T max_val, const int n); +__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); -template +template +__global__ void kQuantizeBlockwise( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); +template __global__ void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n); - + kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +template +__global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); -template __global__ void -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n); - -template __global__ void kOptimizerStatic8bit2StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); - -template __global__ void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n); - - -template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); - - -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); - - -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); - -template __global__ void kdequant_mm_int32_fp16( - int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); - -template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); -template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); - -template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); - -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); -template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); - -template __global__ void kfunc(T *A, T *B, T value, long n); + kHistogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, const int maxidx1, const int n); + +template +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); + +template +__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template +__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); + +template +__global__ void kTransformRowToFormat( + char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols +); + +template +__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc); +template +__global__ void kgemm_4bit_inference( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, + int blocksize +); +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kfunc(T* A, T* B, T value, long n); #endif diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index bcfc73e99..624ebe326 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -4,42 +4,42 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - #ifndef ops_H #define ops_H +#include #include -#include #include +#include #include -#include -#include +#include #include -#include +#include #include #include +#include #include -#include - -#define CUDA_CHECK_RETURN(value) { \ - hipError_t _m_cudaStat = value; \ - if (_m_cudaStat != hipSuccess) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } } - - -#define CHECK_HIPSPARSE(value) { \ - hipsparseStatus_t _m_hipStat = value; \ - if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__); \ - exit(1); \ - } } +#define CUDA_CHECK_RETURN(value) \ + { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } +#define CHECK_HIPSPARSE(value) \ + { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf( \ + stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__ \ + ); \ + exit(1); \ + } \ + } inline void checkHipStatus(hipError_t status) { if (status != hipSuccess) { @@ -51,145 +51,167 @@ inline void checkHipStatus(hipError_t status) { inline int checkHipblasStatus(hipblasStatus_t status) { if (status != HIPBLAS_STATUS_SUCCESS) { printf("hipBLAS API failed with status %d\n", status); - //throw std::logic_error("cuBLAS API failed"); + // throw std::logic_error("cuBLAS API failed"); return 1; } return 0; } -typedef enum Operations_t -{ - ksmul = 0, +typedef enum Operations_t { + ksmul = 0, } Operations_t; -typedef enum Optimizer_t -{ - ADAM = 0, - MOMENTUM = 1, - RMSPROP = 2, - LARS = 3, - ADAGRAD = 4, - LION = 5, - ADEMAMIX = 6, +typedef enum Optimizer_t { + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, } Optimizer_t; -typedef enum Transform_t -{ - ROW = 0, - COL = 1, - COL32 = 2, - COL_TURING = 3, - COL_AMPERE = 4, +typedef enum Transform_t { + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, } Transform_t; -typedef enum DataType_t -{ - General8bit = 0, - FP4 = 1, - NF4 = 2, +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, } DataType_t; -typedef enum Funcs_t -{ - FILL = 0, - ARANGE = 1, - _MUL = 2, +typedef enum Funcs_t { + FILL = 0, + ARANGE = 1, + _MUL = 2, } Funcs_t; -class Context -{ - public: - rocblas_handle m_handle; - - Context() - { - rocblas_handle handle; - rocblas_create_handle(&handle); - m_handle = handle; - } - -}; +class Context { + public: + rocblas_handle m_handle; -class ContextLt -{ - public: - hipblasLtHandle_t m_handle; - - ContextLt() - { - hipblasLtHandle_t handle; - hipblasLtCreate(&handle); - m_handle = handle; - } + Context() { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } }; -class ContextHipsparse -{ - public: - hipsparseHandle_t m_handle; - - ContextHipsparse() - { - hipsparseHandle_t handle; - hipsparseCreate(&handle); - m_handle = handle; - } +class ContextLt { + public: + hipblasLtHandle_t m_handle; + ContextLt() { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } }; +class ContextHipsparse { + public: + hipsparseHandle_t m_handle; -template void estimateQuantiles(T *A, float *code, float offset, int n); - -void quantize(float *code, float *A, unsigned char *out, int n); -void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, hipStream_t stream); - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, - int step, float lr, const float gnorm_scale, bool skip_zeros, int n); - -template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float* quantiles1, float* quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, int n); - -template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, - bool skip_zeros, int n); - -template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); - -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); - -void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, - long long int strideA, long long int strideB, long long int strideC, int batchCount); - - -template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); - -void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, hipStream_t stream); -void getRowStats(half * A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); - -void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); - -void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + ContextHipsparse() { + hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } +}; -template void func(T *A, T *B, T value, long n); +template void estimateQuantiles(T* A, float* code, float offset, int n); + +void quantize(float* code, float* A, unsigned char* out, int n); +void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream); +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, hipStream_t stream +); + +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2, + float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale, + bool skip_zeros, int n +); + +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n +); + +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n +); + +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, int maxidx1, int n); + +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +); +void strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +); + +template +int igemmlt( + hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, hipStream_t stream +); + +void cutlass_igemm( + bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc +); +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream +); +void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream); +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream +); + +void spmm_coo( + hipsparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, half* B, int ldc, half* C, bool transposed_B +); + +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); + +void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits); +template +void gemm_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize +); +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, hipStream_t stream +); + +template void func(T* A, T* B, T value, long n); #endif From 562473620ee321fd3d74ed42b06775f245b35282 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 18 Jun 2025 14:58:01 +0530 Subject: [PATCH 081/102] lint --- conflicts.diff | 431 ------------------------------------------------- 1 file changed, 431 deletions(-) delete mode 100644 conflicts.diff diff --git a/conflicts.diff b/conflicts.diff deleted file mode 100644 index bab359251..000000000 --- a/conflicts.diff +++ /dev/null @@ -1,431 +0,0 @@ -diff --cc .github/workflows/python-package.yml -index 3673ac6,d3deb26..0000000 ---- a/.github/workflows/python-package.yml -+++ b/.github/workflows/python-package.yml -@@@ -218,7 -173,14 +218,18 @@@ jobs - merge-multiple: true - - - name: Inspect tmp directory after downloading artifacts -++<<<<<<< HEAD - + run: ls -alFR tmp/ -++======= -+ run: | -+ ls -alFR tmp/ -+ WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l) -+ echo "Found $WHEEL_COUNT wheel files" -+ if [ "$WHEEL_COUNT" -eq 0 ]; then -+ echo "::error::No wheel files found in tmp directory! Cannot proceed with release." -+ exit 1 -+ fi -++>>>>>>> upstream/main - - - name: Move and rename wheel files with pattern replacement - run: | -@@@ -245,9 -207,20 +256,23 @@@ - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ - -++<<<<<<< HEAD -++======= -+ - uses: actions/checkout@v4 -+ with: -+ path: repo -++>>>>>>> upstream/main - - name: Delete old pre-release (if exists) - run: | -- gh release delete continuous-release_main --cleanup-tag -y || true -+ cd repo && gh release delete continuous-release_main --cleanup-tag -y -+ env: -+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -+ -+ - name: Ensure tag exists -+ run: | -+ cd repo -+ git tag -f continuous-release_main -+ git push -f origin continuous-release_main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - -diff --cc bitsandbytes/cextension.py -index 5283df9,b112df2..0000000 ---- a/bitsandbytes/cextension.py -+++ b/bitsandbytes/cextension.py -@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec - override_value = os.environ.get("BNB_CUDA_VERSION") - if override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) - + if torch.version.hip: - + raise RuntimeError( - + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" - + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" - + ) - logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" -- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" -+ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" -- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" -- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi - return BNBNativeLibrary(dll) - - - +ROCM_GPU_ARCH = get_rocm_gpu_arch() - + - try: -++<<<<<<< HEAD - + if torch.version.hip: - + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - + else: - + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - + -++======= -+ # to support Intel CPU/GPU (XPU) backend -+ import intel_extension_for_pytorch as ipex -+ -+ ipex_cpu = ipex if ipex._C._has_cpu() else None -+ ipex_xpu = ipex if ipex._C._has_xpu() else None -+ except BaseException: -+ ipex_cpu = None -+ ipex_xpu = None -+ -+ -+ try: -++>>>>>>> upstream/main - lib = get_native_library() - except Exception as e: - error_msg = str(e) -diff --cc bitsandbytes/diagnostics/cuda.py -index b9db101,e763ef2..0000000 ---- a/bitsandbytes/diagnostics/cuda.py -+++ b/bitsandbytes/diagnostics/cuda.py -@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path - +from bitsandbytes.consts import NONPYTORCH_DOC_URL -++======= -+ from bitsandbytes.cextension import get_cuda_bnb_library_path -++>>>>>>> upstream/main - from bitsandbytes.cuda_specs import CUDASpecs - from bitsandbytes.diagnostics.utils import print_dedented - -@@@ -148,42 -127,8 +136,38 @@@ def _print_cuda_diagnostics(cuda_specs - """, - ) - -- # TODO: -- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) -- # (2) Multiple CUDA versions installed -- - - -def print_cuda_runtime_diagnostics() -> None: - +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") - + - + binary_path = get_cuda_bnb_library_path(cuda_specs) - + if not binary_path.exists(): - + print_dedented( - + f""" - + Library not found: {binary_path}. - + Maybe you need to compile it from source? If you compiled from source, check that ROCm version - + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version - + and rebuild bitsandbytes. - + """, - + ) - + - + hip_major, hip_minor = cuda_specs.cuda_version_tuple - + if (hip_major, hip_minor) < (6, 1): - + print_dedented( - + """ - + WARNING: bitsandbytes is fully supported only from ROCm 6.1. - + """, - + ) - + - + - +def print_diagnostics(cuda_specs: CUDASpecs) -> None: - + if HIP_ENVIRONMENT: - + _print_hip_diagnostics(cuda_specs) - + else: - + _print_cuda_diagnostics(cuda_specs) - + - + - +def _print_cuda_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") -diff --cc bitsandbytes/diagnostics/main.py -index bf31d79,aa4cb30..0000000 ---- a/bitsandbytes/diagnostics/main.py -+++ b/bitsandbytes/diagnostics/main.py -@@@ -3,12 -5,11 +5,20 @@@ import tracebac - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT - +from bitsandbytes.consts import PACKAGE_GITHUB_URL - +from bitsandbytes.cuda_specs import get_cuda_specs - +from bitsandbytes.diagnostics.cuda import ( - + print_diagnostics, - + print_runtime_diagnostics, -++======= -+ from bitsandbytes import __version__ as bnb_version -+ from bitsandbytes.consts import PACKAGE_GITHUB_URL -+ from bitsandbytes.cuda_specs import get_cuda_specs -+ from bitsandbytes.diagnostics.cuda import ( -+ print_cuda_diagnostics, -++>>>>>>> upstream/main - ) - from bitsandbytes.diagnostics.utils import print_dedented, print_header - -@@@ -28,53 -41,77 +50,123 @@@ def sanity_check() - assert p1 != p2 - - -+ def get_package_version(name: str) -> str: -+ try: -+ version = importlib.metadata.version(name) -+ except importlib.metadata.PackageNotFoundError: -+ version = "not found" -+ return version -+ -+ -+ def show_environment(): -+ """Simple utility to print out environment information.""" -+ -+ print(f"Platform: {platform.platform()}") -+ if platform.system() == "Linux": -+ print(f" libc: {'-'.join(platform.libc_ver())}") -+ -+ print(f"Python: {platform.python_version()}") -+ -+ print(f"PyTorch: {torch.__version__}") -+ print(f" CUDA: {torch.version.cuda or 'N/A'}") -+ print(f" HIP: {torch.version.hip or 'N/A'}") -+ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") -+ -+ print("Related packages:") -+ for pkg in _RELATED_PACKAGES: -+ version = get_package_version(pkg) -+ print(f" {pkg}: {version}") -+ -+ - def main(): -- print_header("") -- print_header("BUG REPORT INFORMATION") -+ print_header(f"bitsandbytes v{bnb_version}") -+ show_environment() - print_header("") - -- print_header("OTHER") - cuda_specs = get_cuda_specs() -++<<<<<<< HEAD - + if HIP_ENVIRONMENT: - + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" - + print(f"{BNB_BACKEND} specs:{rocm_specs}") - + else: - + print(f"{BNB_BACKEND} specs:{cuda_specs}") - + if not torch.cuda.is_available(): - + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - + if not HIP_ENVIRONMENT: - + print(f"- {BNB_BACKEND} driver not installed") - + print(f"- {BNB_BACKEND} not installed") - + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") - + if cuda_specs: - + print_diagnostics(cuda_specs) - + print_runtime_diagnostics() - + print_header("") - + print_header("DEBUG INFO END") - + print_header("") - + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - + try: - + sanity_check() - + print("SUCCESS!") - + print("Installation was successful!") - + return - + except RuntimeError as e: - + if "not available in CPU-only" in str(e): - + print( - + f"WARNING: {__package__} is currently running as CPU-only!\n" - + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - + f"If you think that this is so erroneously,\nplease report an issue!", - + ) - + else: - + raise e - + except Exception: - + traceback.print_exc() - + print_dedented( - + f""" - + Above we output some debug information. - + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - + WARNING: Please be sure to sanitize sensitive info from the output before posting it. - + """, - + ) - + sys.exit(1) -++======= -+ -+ if cuda_specs: -+ print_cuda_diagnostics(cuda_specs) -+ -+ # TODO: There's a lot of noise in this; needs improvement. -+ # print_cuda_runtime_diagnostics() -+ -+ if not torch.cuda.is_available(): -+ print("PyTorch says CUDA is not available. Possible reasons:") -+ print("1. CUDA driver not installed") -+ print("2. Using a CPU-only PyTorch build") -+ print("3. No GPU detected") -+ -+ else: -+ print("Checking that the library is importable and CUDA is callable...") -+ -+ try: -+ sanity_check() -+ print("SUCCESS!") -+ return -+ except RuntimeError as e: -+ if "not available in CPU-only" in str(e): -+ print( -+ f"WARNING: {__package__} is currently running as CPU-only!\n" -+ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" -+ f"If you think that this is so erroneously,\nplease report an issue!", -+ ) -+ else: -+ raise e -+ except Exception: -+ traceback.print_exc() -+ -+ print_dedented( -+ f""" -+ Above we output some debug information. -+ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose -+ WARNING: Please be sure to sanitize sensitive info from the output before posting it. -+ """, -+ ) -+ sys.exit(1) -++>>>>>>> upstream/main -diff --cc bitsandbytes/functional.py -index 9b7ce2d,ffb6668..0000000 -mode 100644,100755..100755 ---- a/bitsandbytes/functional.py -+++ b/bitsandbytes/functional.py -@@@ -13,9 -13,9 +13,13 @@@ import torc - from torch import Tensor - from typing_extensions import deprecated - -- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -+ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict - -++<<<<<<< HEAD - +from .cextension import HIP_ENVIRONMENT, lib -++======= -+ from .cextension import ipex_cpu, ipex_xpu, lib -++>>>>>>> upstream/main - - name2qmap = {} - -diff --cc bitsandbytes/nn/modules.py -index a2facac,e349cc8..0000000 ---- a/bitsandbytes/nn/modules.py -+++ b/bitsandbytes/nn/modules.py -@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype - import torch.nn.functional as F - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from bitsandbytes.functional import QuantState -++======= -+ from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu -++>>>>>>> upstream/main - from bitsandbytes.optim import GlobalOptimManager - from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, -diff --cc tests/test_linear4bit.py -index 60c1634,b5db2eb..0000000 ---- a/tests/test_linear4bit.py -+++ b/tests/test_linear4bit.py -@@@ -7,8 -8,14 +8,19 @@@ import pytes - import torch - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer -++======= -+ from tests.helpers import ( -+ TRUE_FALSE, -+ describe_dtype, -+ get_available_devices, -+ id_formatter, -+ torch_load_from_buffer, -+ torch_save_to_buffer, -+ ) -++>>>>>>> upstream/main - - storage = { - "uint8": torch.uint8, -@@@ -184,16 -185,10 +190,10 @@@ def test_linear_serialization(device, q - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_copy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -209,16 -204,10 +209,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -241,16 -230,10 +235,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) -+ original_tensor = torch.randn(300, 400) - original_param = bnb.nn.Params4bit( - data=original_tensor, - quant_type=quant_type, From c75fdb7d52feb7d4b11a0e1141b91c50a1c04c4e Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:02:59 +0530 Subject: [PATCH 082/102] Update pythonInterface.cpp --- csrc/pythonInterface.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index a8d47b8de..9c4cab9cc 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -37,11 +37,6 @@ //=================================================================================== #if BUILD_CUDA || BUILD_HIP -void estimateQuantiles_fp32(float* A, float* code, float offset, int n) { - estimateQuantiles(A, code, offset, n); -} - -void estimateQuantiles_fp16(half* A, float* code, float offset, int n) { estimateQuantiles(A, code, offset, n); } // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } From 3936ca40bffa149bb871b753e5536dcd3ab96817 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Wed, 18 Jun 2025 12:09:27 -0500 Subject: [PATCH 083/102] revert permissions change --- bitsandbytes/functional.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 bitsandbytes/functional.py diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py old mode 100755 new mode 100644 From b4fd5942b07d65a2084656bc79221caab4d7f3fa Mon Sep 17 00:00:00 2001 From: pnunna93 <104791500+pnunna93@users.noreply.github.com> Date: Wed, 18 Jun 2025 12:31:24 -0500 Subject: [PATCH 084/102] Fix indentation --- bitsandbytes/diagnostics/main.py | 48 ++++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 9a0447433..74da662b6 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -92,27 +92,27 @@ def main(): else: print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - try: - sanity_check() - print("SUCCESS!") - return - except RuntimeError as e: - if "not available in CPU-only" in str(e): - print( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!", - ) - else: - raise e - except Exception: - traceback.print_exc() - - print_dedented( - f""" - Above we output some debug information. - Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - WARNING: Please be sure to sanitize sensitive info from the output before posting it. - """, - ) - sys.exit(1) + try: + sanity_check() + print("SUCCESS!") + return + except RuntimeError as e: + if "not available in CPU-only" in str(e): + print( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!", + ) + else: + raise e + except Exception: + traceback.print_exc() + + print_dedented( + f""" + Above we output some debug information. + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + """, + ) + sys.exit(1) From 3228ca86d74a50d4f7c5170bc473d29c30f3dec5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:15:25 +0530 Subject: [PATCH 085/102] Update kernels_hip.cuh --- csrc/kernels_hip.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 811299d05..d902129a3 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -103,9 +103,6 @@ __global__ void kOptimizerStatic8bit1StateBlockwise( template __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); -__global__ void - kHistogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, const int maxidx1, const int n); - template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, From 94c1b7751bdd1d10014cf861a4e28ede66262530 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:21:11 +0530 Subject: [PATCH 086/102] Update kernels.hip --- csrc/kernels.hip | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 56e1d54db..53b2725a3 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -346,18 +346,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) -{ - const int tid = threadIdx.x + (blockDim.x*blockIdx.x); - const int numThreads = blockDim.x*gridDim.x; - - for(int i = tid; i < n; i+=numThreads) - { - int idx = (index1[i]*maxidx1) + index2[i]; - atomicAdd(&histogram[idx], src[i]); - } -} - #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 From cd3f0b779f6c285cd969689dd509ad08698e0964 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:23:14 +0530 Subject: [PATCH 087/102] Update ops.hip --- csrc/ops.hip | 9 --------- 1 file changed, 9 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index a9c3e0202..ccdbc1026 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -24,15 +24,6 @@ using namespace BinSearch; using std::cout; using std::endl; -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) -{ - int threads = 512; - int num_blocks = n/threads; - num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - template void estimateQuantiles(T *A, float *code, float offset, int n) { int num_blocks = n/4096; From 98bb06ed6245da3af44497c1df04c8da06f00d2a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:25:32 +0530 Subject: [PATCH 088/102] Update ops_hip.cuh --- csrc/ops_hip.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 624ebe326..ebae292c4 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -160,8 +160,6 @@ void optimizerStatic8bitBlockwise( template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); -void histogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, int maxidx1, int n); - void gemmex( Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc From 3bad4541e3d9fc186cf680009bfef7c980bb0aaa Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:17:59 +0530 Subject: [PATCH 089/102] Update kernels_hip.cuh --- csrc/kernels_hip.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index d902129a3..00718071c 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -11,10 +11,6 @@ #ifndef kernels #define kernels -template -__global__ void - kEstimateQuantiles(T* __restrict__ const A, float* code, const float offset, const T max_val, const int n); - __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); From e0c766dcc34b6147d5a6e8aa505dbb15c08233a5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:20:37 +0530 Subject: [PATCH 090/102] Update kernels.hip --- csrc/kernels.hip | 73 ------------------------------------------------ 1 file changed, 73 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 53b2725a3..6b0f1eac5 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -346,79 +346,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -#define THREADS_ESTIMATE 512 -#define NUM_ESTIMATE 8 -#define BLOCK_ESTIMATE 4096 - -template -__launch_bounds__(THREADS_ESTIMATE, 1) -__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) -{ - const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); - int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; - const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); - const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); - - T vals[NUM_ESTIMATE]; - - typedef hipcub::BlockRadixSort BlockRadixSort; - typedef hipcub::BlockLoad LoadFloat; - - __shared__ union { - typename LoadFloat::TempStorage loadf; - typename BlockRadixSort::TempStorage sort; - int smem_qidx[BLOCK_ESTIMATE]; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) - { - valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; - - // do not process half-blocks - if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = max_val; - - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = ((float)vals[j]) * reciprocal_num_blocks; - - - __syncthreads(); - // sort into striped pattern to mitigate bank conflicts - // striped pattern index for thread 0 [0, 1024, 2048, 3096] - // striped pattern index for thread 1 [1, 1025, 2049, 3097] - BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - - __syncthreads(); - for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) - temp_storage.smem_qidx[j] = -1; - - __syncthreads(); - - if(threadIdx.x < 256) - { - float q_interval = (1.0f-(2.0f*offset))/255.0f; - int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); - temp_storage.smem_qidx[local_idx] = threadIdx.x; - } - - __syncthreads(); - - for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) - { - if(temp_storage.smem_qidx[i] != -1) - atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); - } - } -} - - __launch_bounds__(TH, 4) __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) { From f35a063db5bd5fb87c0ccf70df2687b7079b33af Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:22:55 +0530 Subject: [PATCH 091/102] Update kernels.hip --- csrc/kernels.hip | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 6b0f1eac5..ec3f7f025 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2899,9 +2899,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); -template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); - #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ From fca01f310358169d49b686bce1fae7a9c4d37c93 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:30:34 +0530 Subject: [PATCH 092/102] Update ops.hip --- csrc/ops.hip | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index ccdbc1026..1840b7864 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -743,9 +743,6 @@ template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, con template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); -template void estimateQuantiles(half *A, float *code, float offset, int n); -template void estimateQuantiles(float *A, float *code, float offset, int n); - template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); From 5569c2de672006ed6353cf85e0a34b4ddeec59a1 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:34:01 +0530 Subject: [PATCH 093/102] Update ops_hip.cuh --- csrc/ops_hip.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index ebae292c4..0f8db2ee4 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -124,8 +124,6 @@ class ContextHipsparse { } }; -template void estimateQuantiles(T* A, float* code, float offset, int n); - void quantize(float* code, float* A, unsigned char* out, int n); void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream); template From 7a17f2d6f7ecfb78cf72d94de4b3f3f3ef4e1453 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:44:51 +0530 Subject: [PATCH 094/102] Update ops.hip --- csrc/ops.hip | 9 --------- 1 file changed, 9 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index 1840b7864..260b74b30 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -24,15 +24,6 @@ using namespace BinSearch; using std::cout; using std::endl; -template void estimateQuantiles(T *A, float *code, float offset, int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); - hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - void quantize(float *code, float *A, unsigned char *out, int n) { int num_blocks = n/1024; From 6b8239e707ba7e63bdf3abbac7d365c0a6a0dbfb Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:33:09 +0530 Subject: [PATCH 095/102] Update CMakeLists.txt --- CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a7583279..770b4ba30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,9 +195,6 @@ elseif(BUILD_HIP) string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") - if(HIP_VERSION VERSION_LESS "6.1") - string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") - endif() add_compile_definitions(__HIP_PLATFORM_AMD__) add_compile_definitions(__HIP_PLATFORM_HCC__) add_compile_definitions(BUILD_HIP) From 00ac146878bf64ac12c923aaae7ec00283f0ecde Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:48:31 +0530 Subject: [PATCH 096/102] Update functional.py --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3c0a41351..9b446a2de 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -908,7 +908,7 @@ def quantize_4bit( absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. @@ -1019,7 +1019,7 @@ def dequantize_4bit( Required if `quant_state` is not provided and ignored otherwise. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. From 77f4c7747c6354b841f75442f13c2b595bee1a96 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 17:29:43 +0530 Subject: [PATCH 097/102] Update cextension.py --- bitsandbytes/cextension.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7f5483531..1c5197647 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -23,8 +23,6 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" - blas_suffix = "_nohipblaslt" if torch.version.hip and cuda_specs.cuda_version_tuple < (6, 1) else "" - library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{blas_suffix}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: From c9fe2845a5bf440dfb32ccd0680f6dda41ad8096 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 17:43:24 +0530 Subject: [PATCH 098/102] Update cextension.py --- bitsandbytes/cextension.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 1c5197647..bb301e712 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -23,6 +23,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" + library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: From 7d4854edcee2f20944466729e944d40809d4554c Mon Sep 17 00:00:00 2001 From: sstamenk Date: Fri, 25 Jul 2025 15:40:13 +0200 Subject: [PATCH 099/102] warpSize is being made non constexpr in ROCm 7.0 --- csrc/kernels.hip | 20 ++++++++++++-------- csrc/ops.hip | 8 +++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index ec3f7f025..58f6ed065 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2109,7 +2109,11 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 -#define WARP_SIZE warpSize +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2708,13 +2712,13 @@ template __global__ void kgemm_4bit_inferenc // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE]; - const int warp_idx = threadIdx.x / warpSize; - const int warp_lane = threadIdx.x % warpSize; - const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; - const int offset_B = ldb*row_B; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int warp_lane = threadIdx.x % WARP_SIZE; + const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -2732,7 +2736,7 @@ template __global__ void kgemm_4bit_inferenc // A: [1, K] // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit) { const int inner_idx_halved = inner_idx/2; diff --git a/csrc/ops.hip b/csrc/ops.hip index 260b74b30..b26d138e1 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -20,6 +20,12 @@ #define ERR_NOT_IMPLEMENTED 100 +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif + using namespace BinSearch; using std::cout; using std::endl; @@ -692,7 +698,7 @@ template void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 - if (warpSize == 64) { + if (WARP_SIZE == 64) { num_blocks = (m+1)/2; } From 2e65b38ca42ccfa7ad527de6b2be90ea2d53fc46 Mon Sep 17 00:00:00 2001 From: pnunna93 <104791500+pnunna93@users.noreply.github.com> Date: Tue, 23 Sep 2025 21:45:47 -0500 Subject: [PATCH 100/102] Merge pull request #90 from ROCm/IFU-rocm_enabled-09-23-2025 Ifu rocm enabled 09 23 2025 --- .github/FUNDING.yml | 1 + .github/scripts/build-cuda.sh | 12 +- .github/workflows/python-package.yml | 7 +- .github/workflows/tests.yml | 62 +- CMakeLists.txt | 32 +- MANIFEST.in | 3 + README.md | 36 +- _typos.toml | 7 + benchmarking/inference_benchmark.py | 101 +- benchmarking/xpu/inference_benchmark.py | 147 +++ bitsandbytes/__init__.py | 3 +- bitsandbytes/_ops.py | 117 +- bitsandbytes/autograd/_functions.py | 22 +- bitsandbytes/backends/cpu/ops.py | 171 ++- bitsandbytes/backends/cuda/ops.py | 226 ++++ bitsandbytes/backends/default/ops.py | 277 +++- bitsandbytes/backends/hpu/ops.py | 11 +- .../{triton_kernels.py => kernels_4bit.py} | 165 +-- .../backends/triton/kernels_8bit_quant.py | 195 +++ bitsandbytes/backends/triton/kernels_optim.py | 1154 +++++++++++++++++ bitsandbytes/backends/triton/ops.py | 227 +++- bitsandbytes/backends/utils.py | 12 +- bitsandbytes/backends/xpu/__init__.py | 0 bitsandbytes/backends/xpu/ops.py | 221 +++- bitsandbytes/cextension.py | 42 +- bitsandbytes/functional.py | 280 +--- bitsandbytes/nn/modules.py | 85 +- bitsandbytes/nn/parametrize.py | 192 +++ bitsandbytes/optim/adamw.py | 18 +- bitsandbytes/optim/lars.py | 3 - bitsandbytes/optim/optimizer.py | 24 +- bitsandbytes/py.typed | 0 bitsandbytes/research/autograd/_functions.py | 2 +- bitsandbytes/triton/triton_utils.py | 7 +- bitsandbytes/utils.py | 20 +- csrc/kernels.cu | 132 +- csrc/kernels.hip | 132 +- csrc/pythonInterface.cpp | 169 +++ csrc/xpu_kernels.cpp | 281 ++++ csrc/xpu_kernels.h | 52 + csrc/xpu_ops.cpp | 102 ++ csrc/xpu_ops.h | 46 + docs/source/installation.mdx | 52 +- install_cuda.py | 2 +- pyproject.toml | 14 +- setup.py | 28 +- tests/conftest.py | 2 + tests/helpers.py | 7 +- tests/test_functional.py | 134 +- tests/test_generation.py | 2 +- tests/test_linear4bit.py | 35 + tests/test_linear8bitlt.py | 9 +- tests/test_modules.py | 15 +- tests/test_ops.py | 6 +- tests/test_optim.py | 69 +- tests/test_parametrize.py | 411 ++++++ 56 files changed, 4442 insertions(+), 1140 deletions(-) create mode 100644 .github/FUNDING.yml create mode 100644 MANIFEST.in create mode 100644 benchmarking/xpu/inference_benchmark.py rename bitsandbytes/backends/triton/{triton_kernels.py => kernels_4bit.py} (78%) create mode 100644 bitsandbytes/backends/triton/kernels_8bit_quant.py create mode 100644 bitsandbytes/backends/triton/kernels_optim.py mode change 100755 => 100644 bitsandbytes/backends/utils.py mode change 100755 => 100644 bitsandbytes/backends/xpu/__init__.py mode change 100755 => 100644 bitsandbytes/backends/xpu/ops.py create mode 100644 bitsandbytes/nn/parametrize.py create mode 100644 bitsandbytes/py.typed create mode 100644 csrc/xpu_kernels.cpp create mode 100644 csrc/xpu_kernels.h create mode 100644 csrc/xpu_ops.cpp create mode 100644 csrc/xpu_ops.h create mode 100644 tests/test_parametrize.py diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..8e5903655 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +open_collective: bitsandbytes diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh index 8985327f2..b13d9c92b 100644 --- a/.github/scripts/build-cuda.sh +++ b/.github/scripts/build-cuda.sh @@ -11,14 +11,14 @@ if [[ -v cuda_targets ]]; then elif [ "${build_arch}" = "aarch64" ]; then build_capability="75;80;90" - # CUDA 12.8: Add sm100 - [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;90;100" + # CUDA 12.8+: Add sm100/sm120 + [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120" else - # By default, target Maxwell through Hopper. - build_capability="50;52;60;61;70;75;80;86;89;90" + # By default, target Pascal through Hopper. + build_capability="60;70;75;80;86;89;90" - # CUDA 12.8: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum - [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;86;89;90;100;120" + # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum + [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120" fi [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 827c2ffbf..a11b13f33 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -72,16 +72,17 @@ jobs: - os: windows-latest arch: x86_64 cuda_version: - ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] + ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1", "12.9.1"] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 # Windows: We install Cuda on the agent (slow) - - uses: Jimver/cuda-toolkit@v0.2.22 + - uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24 if: startsWith(matrix.os, 'windows') id: cuda-toolkit with: - cuda: ${{ matrix.cuda_version }} + # Temporary: Use CUDA 12.9.0 for Windows until 12.9.1 is supported with this action. + cuda: ${{ matrix.cuda_version == '12.9.1' && '12.9.0' || matrix.cuda_version }} method: "network" sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' linux-local-args: '["--toolkit"]' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0d3884593..997da52bd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -49,8 +49,8 @@ jobs: build-cuda: strategy: matrix: - cuda_version: ["11.8.0", "12.6.3", "12.8.1"] - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025] + cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"] + os: [ubuntu-22.04, ubuntu-22.04-arm] include: - os: ubuntu-22.04 arch: x86_64 @@ -58,13 +58,14 @@ jobs: arch: aarch64 - os: windows-2025 arch: x86_64 + cuda_version: "11.8.0" runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - name: Install CUDA Toolkit - uses: Jimver/cuda-toolkit@v0.2.23 + uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24 if: startsWith(matrix.os, 'windows') id: cuda-toolkit with: @@ -100,8 +101,8 @@ jobs: fail-fast: false matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15] - # Test with the oldest supported torch version and the two newest. - torch_version: ["2.2.2", "2.6.0", "2.7.1"] + # Test with the oldest supported torch version, the newest two stable/RC. + torch_version: ["2.3.1", "2.7.1", "2.8.0"] include: - os: ubuntu-22.04 arch: x86_64 @@ -117,7 +118,7 @@ jobs: arch: arm64 exclude: - os: ubuntu-22.04-arm - torch_version: "2.2.2" + torch_version: "2.3.1" runs-on: ${{ matrix.runner || matrix.os }} env: @@ -147,9 +148,10 @@ jobs: pip install -e ".[test]" pip install pytest-cov - # We need to downgrade to numpy<2 for torch<2.3 compatibility. + # We need to downgrade to numpy<2 for torch<2.4.1 compatibility on Windows + # See: https://github.com/pytorch/pytorch/issues/131668 - name: Downgrade NumPy - if: startsWith(matrix.torch_version, '2.2.') + if: startsWith(matrix.os, 'windows') && startsWith(matrix.torch_version, '2.3.') run: pip install "numpy<2" - name: Show installed packages @@ -161,7 +163,7 @@ jobs: - name: Run tests run: pytest --durations=100 - test-cpu-ipex: + test-cpu-intel: if: github.repository == 'bitsandbytes-foundation/bitsandbytes' needs: build-cpu runs-on: banb-aws-general-8-plus-use1-public-80 @@ -185,7 +187,6 @@ jobs: - name: Install dependencies run: | pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu - pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ pip install -e ".[test]" pip install pytest-cov @@ -195,9 +196,6 @@ jobs: - name: Show environment information run: python -m torch.utils.collect_env - - name: IPEX smoke test - run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);" - - name: Run tests run: pytest --durations=100 @@ -223,7 +221,7 @@ jobs: # run: pip list test-hpu: - if: github.repository == 'bitsandbytes-foundation/bitsandbytes' + if: false # github.repository == 'bitsandbytes-foundation/bitsandbytes' needs: build-cpu strategy: fail-fast: false @@ -279,21 +277,12 @@ jobs: run: pytest --durations=100 test-xpu: - if: github.repository == 'bitsandbytes-foundation/bitsandbytes' + if: false # github.repository == 'bitsandbytes-foundation/bitsandbytes' needs: build-cpu strategy: fail-fast: false matrix: torch_version: ["2.7.1"] #["2.6.0", "2.7.1"] - ipex: [false] - # ipex: [true, false] - # include: - # - torch_version: "2.6.0" - # ipex: true - # ipex_version: "2.6.10+xpu" - # - torch_version: "2.7.1" - # ipex: true - # ipex_version: "2.7.10+xpu" runs-on: group: bandb-itac-bmsprpvc1550-8-1gpu env: @@ -329,10 +318,6 @@ jobs: - name: Install PyTorch run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu - - name: Install IPEX - if: matrix.ipex == true - run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ - - name: Install dependencies run: | pip install -e ".[test]" @@ -358,10 +343,10 @@ jobs: os: [ubuntu-22.04, windows-2025] arch: [x86_64] gpu: [T4, L40S] - cuda_version: ["11.8.0", "12.6.3", "12.8.1"] + cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"] include: - cuda_version: "11.8.0" - torch_version: "2.2.2" + torch_version: "2.3.1" pypi_index: "https://download.pytorch.org/whl/cu118" - cuda_version: "12.6.3" torch_version: "2.6.0" @@ -369,6 +354,9 @@ jobs: - cuda_version: "12.8.1" torch_version: "2.7.1" pypi_index: "https://download.pytorch.org/whl/cu128" + - cuda_version: "12.9.1" + torch_version: "2.8.0" + pypi_index: "https://download.pytorch.org/whl/cu129" # Linux L40S runners @@ -387,7 +375,7 @@ jobs: gpu: T4 runner: CUDA-Windows-x64 cuda_version: "11.8.0" - torch_version: "2.2.0" + torch_version: "2.3.1" pypi_index: "https://download.pytorch.org/whl/cu118" - os: windows-2025 arch: x86_64 @@ -401,12 +389,14 @@ jobs: gpu: T4 runner: CUDA-Windows-x64 cuda_version: "11.8.0" - torch_version: "2.7.1" + torch_version: "2.7.1" # Note: this is the last PyTorch release supporting CUDA 11.8. pypi_index: "https://download.pytorch.org/whl/cu118" exclude: # Our current T4 Windows runner has a driver too old (471.11) # and cannot support CUDA 12+. Skip for now. + - os: windows-2025 + cuda_version: "12.9.1" - os: windows-2025 cuda_version: "12.8.1" - os: windows-2025 @@ -438,15 +428,9 @@ jobs: - name: Install dependencies run: | - pip install torch==${{ matrix.torch_version }} --index-url ${{ matrix.pypi_index }} + pip install --pre torch~=${{ matrix.torch_version }}.dev0 --index-url ${{ matrix.pypi_index }} pip install -e ".[test]" pip install pytest-cov - - # We need to downgrade to numpy<2 for torch<2.3 compatibility. - - name: Downgrade NumPy - if: startsWith(matrix.torch_version, '2.2.') - run: pip install "numpy<2" - - name: Show installed packages run: pip list diff --git a/CMakeLists.txt b/CMakeLists.txt index 770b4ba30..d9529b0d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -64,10 +65,19 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "xpu") + if(APPLE) + message(FATAL_ERROR "XPU is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + set(BUILD_XPU ON) else() set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS OFF) + set(BUILD_XPU OFF) endif() @@ -217,6 +227,15 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_XPU) + list(APPEND SRC_FILES ${XPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_xpu") + add_compile_definitions(BUILD_XPU) + set(CMAKE_C_COMPILER icx) + set(CMAKE_CXX_COMPILER icpx) + if(WIN32) + set(CMAKE_CXX_COMPILER icx) + endif() else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -285,6 +304,15 @@ if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_XPU) + set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") + set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;") + + set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) + target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) + target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) + +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..00bdaa214 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include CMakeLists.txt +graft csrc +graft include diff --git a/README.md b/README.md index c6c5ff25b..732baea69 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,13 @@ The library includes quantization primitives for 8-bit & 4-bit operations, throu bitsandbytes has the following minimum requirements for all platforms: * Python 3.9+ -* [PyTorch](https://pytorch.org/get-started/locally/) 2.2+ +* [PyTorch](https://pytorch.org/get-started/locally/) 2.3+ * _Note: While we aim to provide wide backwards compatibility, we recommend using the latest version of PyTorch for the best experience._ #### Accelerator support: Note: this table reflects the status of the current development branch. For the latest stable release, see the -[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support). +[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support). ##### Legend: @@ -61,7 +61,7 @@ bitsandbytes has the following minimum requirements for all platforms: 🟩 NVIDIA GPU
cuda - SM50+ minimum
SM75+ recommended + SM60+ minimum
SM75+ recommended āœ… āœ… āœ… @@ -71,11 +71,11 @@ bitsandbytes has the following minimum requirements for all platforms: 🟄 AMD GPU
cuda CDNA: gfx90a, gfx942
- RDNA: gfx1100, gfx1200 + RDNA: gfx1100 - 🚧 - 🚧 - 🚧 + āœ… + ć€°ļø + āœ… @@ -85,14 +85,14 @@ bitsandbytes has the following minimum requirements for all platforms: Arc A-Series (Alchemist)
Arc B-Series (Battlemage) - 🚧 - 🚧 - 🚧 + āœ… + āœ… + ć€°ļø 🟪 Intel Gaudi
hpu - Gaudi1, Gaudi2, Gaudi3 + Gaudi2, Gaudi3 āœ… ć€°ļø āŒ @@ -108,7 +108,7 @@ bitsandbytes has the following minimum requirements for all platforms: 🟩 NVIDIA GPU
cuda - SM75, SM80, SM90, SM100 + SM75+ āœ… āœ… āœ… @@ -127,7 +127,7 @@ bitsandbytes has the following minimum requirements for all platforms: 🟩 NVIDIA GPU
cuda - SM50+ minimum
SM75+ recommended + SM60+ minimum
SM75+ recommended āœ… āœ… āœ… @@ -139,9 +139,9 @@ bitsandbytes has the following minimum requirements for all platforms: Arc A-Series (Alchemist)
Arc B-Series (Battlemage) - 🚧 - 🚧 - 🚧 + āœ… + āœ… + ć€°ļø šŸŽ macOS 14+ @@ -173,7 +173,9 @@ bitsandbytes has the following minimum requirements for all platforms: ## :heart: Sponsors The continued maintenance and development of `bitsandbytes` is made possible thanks to the generous support of our sponsors. Their contributions help ensure that we can keep improving the project and delivering valuable updates to the community. -Hugging Face +Hugging Face +  +Intel ## License `bitsandbytes` is MIT licensed. diff --git a/_typos.toml b/_typos.toml index 955c6cb79..fce018f81 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,4 +1,11 @@ [files] +# Skip these files in typo checks +extend-exclude = [ + "csrc/xpu_ops.h", + "csrc/xpu_ops.cpp", + "csrc/xpu_kernels.h", + "csrc/xpu_kernels.cpp" +] [default] extend-ignore-re = [ diff --git a/benchmarking/inference_benchmark.py b/benchmarking/inference_benchmark.py index 61ac570f2..72ee8cfae 100644 --- a/benchmarking/inference_benchmark.py +++ b/benchmarking/inference_benchmark.py @@ -21,6 +21,9 @@ --batches BATCHES [BATCHES ...] --input-length INPUT_LENGTH --out-dir OUT_DIR + --iterations ITERATIONS + --warmup-runs WARMUP_RUNS + --output-length OUTPUT_LENGTH """ import argparse @@ -30,6 +33,9 @@ from optimum_benchmark.logging_utils import setup_logging import torch +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8 WEIGHTS_CONFIGS = { @@ -73,9 +79,8 @@ }, } -if __name__ == "__main__": - setup_logging(level="INFO") +def parse_args(): parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool") parser.add_argument("model_id", type=str, help="The model checkpoint to use.") @@ -98,37 +103,73 @@ parser.add_argument("--out-dir", type=str, default="reports") - args = parser.parse_args() + parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run") + parser.add_argument( + "--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement" + ) + parser.add_argument( + "--output-length", + type=int, + default=64, + help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.", + ) + + return parser.parse_args() + + +def run_benchmark(args, config, batch_size): + launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn") + scenario_config = InferenceConfig( + latency=True, + memory=True, + input_shapes={"batch_size": batch_size, "sequence_length": args.input_length}, + iterations=args.iterations, + warmup_runs=args.warmup_runs, + # set duration to 0 to disable the duration-based stopping criterion + # this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks + duration=0, + # for consistent results, set a fixed min and max for output tokens + generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length}, + forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length}, + ) + + backend_config = PyTorchConfig( + device="cuda", + device_ids="0", + device_map="auto", + no_weights=False, + model=args.model_id, + **WEIGHTS_CONFIGS[config], + ) + + test_name = ( + f"benchmark-{config}" + f"-bsz-{batch_size}" + f"-isz-{args.input_length}" + f"-osz-{args.output_length}" + f"-iter-{args.iterations}" + f"-wrmup-{args.warmup_runs}" + ) + benchmark_config = BenchmarkConfig( + name=test_name, + scenario=scenario_config, + launcher=launcher_config, + backend=backend_config, + ) + + out_path = out_dir / (test_name + ".json") + print(f"[{test_name}] Starting:") + benchmark_report = Benchmark.launch(benchmark_config) + benchmark_report.save_json(out_path) + + +if __name__ == "__main__": + setup_logging(level="INFO") + args = parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) for batch_size in args.batches: - print(f"Benchmarking batch size: {batch_size}") for config in args.configs: - launcher_config = ProcessConfig(device_isolation=True, start_method="spawn") - scenario_config = InferenceConfig( - latency=True, - memory=True, - input_shapes={"batch_size": batch_size, "sequence_length": args.input_length}, - ) - backend_config = PyTorchConfig( - device="cuda", - device_ids="0", - device_map="auto", - no_weights=False, - model=args.model_id, - **WEIGHTS_CONFIGS[config], - ) - benchmark_config = BenchmarkConfig( - name=f"benchmark-{config}-bsz{batch_size}", - scenario=scenario_config, - launcher=launcher_config, - backend=backend_config, - ) - - out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json" - - benchmark_report = Benchmark.launch(benchmark_config) - benchmark_report.log() - benchmark_report.save_json(out_path) + run_benchmark(args, config, batch_size) diff --git a/benchmarking/xpu/inference_benchmark.py b/benchmarking/xpu/inference_benchmark.py new file mode 100644 index 000000000..055abed2e --- /dev/null +++ b/benchmarking/xpu/inference_benchmark.py @@ -0,0 +1,147 @@ +import argparse +import time + +# import intel_extension_for_pytorch as ipex +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +MAX_NEW_TOKENS = 256 + +get_time = time.time + +system_prompt = "You are a helpful assistant" +user_prompt = """Summarize this text please: + +```Tell me, O muse, of that ingenious hero who travelled far and wide after he had sacked the famous town of Troy. Many cities did he visit, and many were the nations with whose manners and customs he was acquainted; moreover he suffered much by sea while trying to save his own life and bring his men safely home; but do what he might he could not save his men, for they perished through their own sheer folly in eating the cattle of the Sun-god Hyperion; so the god prevented them from ever reaching home. Tell me, too, about all these things, O daughter of Jove, from whatsoever source you may know them. + +So now all who escaped death in battle or by shipwreck had got safely home except Ulysses, and he, though he was longing to return to his wife and country, was detained by the goddess Calypso, who had got him into a large cave and wanted to marry him. But as years went by, there came a time when the gods settled that he should go back to Ithaca; even then, however, when he was among his own people, his troubles were not yet over; nevertheless all the gods had now begun to pity him except Neptune, who still persecuted him without ceasing and would not let him get home. + +Now Neptune had gone off to the Ethiopians, who are at the world's end, and lie in two halves, the one looking West and the other East. He had gone there to accept a hecatomb of sheep and oxen, and was enjoying himself at his festival; but the other gods met in the house of Olympian Jove, and the sire of gods and men spoke first. At that moment he was thinking of Aegisthus, who had been killed by Agamemnon's son Orestes; so he said to the other gods: + +"See now, how men lay blame upon us gods for what is after all nothing but their own folly. Look at Aegisthus; he must needs make love to Agamemnon's wife unrighteously and then kill Agamemnon, though he knew it would be the death of him; for I sent Mercury to warn him not to do either of these things, inasmuch as Orestes would be sure to take his revenge when he grew up and wanted to return home. Mercury told him this in all good will but he would not listen, and now he has paid for everything in full." + +Then Minerva said, "Father, son of Saturn, King of kings, it served Aegisthus right, and so it would any one else who does as he did; but Aegisthus is neither here nor there; it is for Ulysses that my heart bleeds, when I think of his sufferings in that lonely sea-girt island, far away, poor man, from all his friends. It is an island covered with forest, in the very middle of the sea, and a goddess lives there, daughter of the magician Atlas, who looks after the bottom of the ocean, and carries the great columns that keep heaven and earth asunder. This daughter of Atlas has got hold of poor unhappy Ulysses, and keeps trying by every kind of blandishment to make him forget his home, so that he is tired of life, and thinks of nothing but how he may once more see the smoke of his own chimneys. You, sir, take no heed of this, and yet when Ulysses was before Troy did he not propitiate you with many a burnt sacrifice? Why then should you keep on being so angry with him?" + +And Jove said, "My child, what are you talking about? How can I forget Ulysses than whom there is no more capable man on earth, nor more liberal in his offerings to the immortal gods that live in heaven? Bear in mind, however, that Neptune is still furious with Ulysses for having blinded an eye of Polyphemus king of the Cyclopes. Polyphemus is son to Neptune by the nymph Thoosa, daughter to the sea-king Phorcys; therefore though he will not kill Ulysses outright, he torments him by preventing him from getting home. Still, let us lay our heads together and see how we can help him to return; Neptune will then be pacified, for if we are all of a mind he can hardly stand out against us."```""" + +prompt = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, +] + + +def get_inputs(tokenizer): + inputs = tokenizer.apply_chat_template( + prompt, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ) + return inputs + + +def get_streamer(tokenizer): + streamer = Streamer(tokenizer) + # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + return streamer + + +class Streamer: + def __init__(self, tokenizer, print_median=False): + self.times = [] + self.print_median = print_median + self.tokenizer = tokenizer + + def put(self, t): + self.times.append(get_time()) + if len(self.times) > 1: + print(f"Token latency: {1000 * (self.times[-1] - self.times[-2]):.1f} ms") + + if len(self.times) % 10 == 3 and self.print_median: + ts = np.array(self.times) + diff = ts[1:] - ts[:-1] + # print("Token latency:", 1000 * diff, "ms") + print("Token latency median:", np.median(1000 * diff), "ms") + + def print_report(self): + times = np.array(self.times) + diff = times[1:] - times[:-1] + print(f"Median latency: {round(np.median(diff) * 1000, 2)}ms") + percentiles = [10, 25, 50, 75, 90] + print( + "Latency percentiles", + {p: round(1000 * float(np.percentile(diff, p)), 1) for p in percentiles}, + ) + + def end(self, *args): + pass + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Run inference benchmark for LLM models") + parser.add_argument( + "--device", + type=str, + default="xpu", + help="Device to run inference on (e.g., xpu, cuda, cpu)", + ) + parser.add_argument( + "--model-id", + type=str, + default="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", + help="Model ID from Hugging Face or local path", + ) + parser.add_argument( + "--attn", + type=str, + default="eager", + choices=["eager", "flash_attention", "sdpa"], + help="Attention implementation to use", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + device = args.device + model_id = args.model_id + + print(f"Running inference on {device} with model {model_id}") + print(f"Using attention implementation: {args.attn}") + + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn) + + inputs = get_inputs(tokenizer) + streamer = get_streamer(tokenizer) + + inputs = inputs.to(device) + model = model.to(device) + + generation_config = GenerationConfig( + use_cache=True, + forced_eos_token_id=1, + eos_token_id=1, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + ) + + outputs = model.generate( + **inputs, + streamer=streamer, + generation_config=generation_config, + ) + + # Print the final outputs (including the input prompt) + output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + print(r"\Output (including prompt):") + print("-" * 40) + print(output_text) + print("-" * 40) + print(f"Peak memory usage: {torch.xpu.max_memory_allocated() / 1024**2:.0f}MB") + + streamer.print_report() diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 516afa51f..d58b7b441 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -38,7 +38,6 @@ if hasattr(torch, "xpu") and torch.xpu.is_available(): from .backends.xpu import ops as xpu_ops - if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"): # In case not automatically imported import habana_frameworks.torch @@ -76,4 +75,4 @@ def _import_backends(): "optim.optimizer.MockArgs": False, } -__version__ = "0.47.0.dev0" +__version__ = "0.48.0.dev0" diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index a260852f5..532fe7afa 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -4,8 +4,6 @@ import torch -from .cextension import ipex_cpu, ipex_xpu - _IS_TORCH_GTE_24 = False if hasattr(torch.library, "register_fake"): @@ -331,20 +329,105 @@ def _( torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") -if ipex_cpu or ipex_xpu: - # Register the dequantize_nf4_ipex implementation - torch.library.define( - "bitsandbytes::dequantize_nf4_ipex", - "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor", +torch.library.define( + "bitsandbytes::optimizer_update_32bit", + "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()", +) + + +@register_fake("bitsandbytes::optimizer_update_32bit") +def _( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + skip_zeros=False, +) -> None: + torch._check( + g.numel() == p.numel(), + lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + ) + compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + torch._check( + g.dtype in compute_dtypes, + lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + ) + torch._check( + g.dtype == p.dtype, + lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", ) - @register_fake("bitsandbytes::dequantize_nf4_ipex") - def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, - ) -> torch.Tensor: - torch._check_is_size(blocksize) - return torch.empty(shape, dtype=dtype, device=A.device) + +torch.library.define( + "bitsandbytes::optimizer_update_8bit_blockwise", + "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()", +) + + +@register_fake("bitsandbytes::optimizer_update_8bit_blockwise") +def _( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros=False, +) -> None: + torch._check( + g.numel() == p.numel(), + lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + ) + compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + torch._check( + g.dtype in compute_dtypes, + lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + ) + torch._check( + g.dtype == p.dtype, + lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + ) + torch._check( + state1.dtype == torch.uint8, + lambda: f"state1 must be uint8, got {state1.dtype}", + ) + torch._check( + qmap1.dtype == absmax1.dtype == torch.float32, + lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + ) + if state2 is not None: + torch._check( + state2.dtype == torch.uint8, + lambda: f"state2 must be uint8, got {state2.dtype}", + ) + torch._check( + qmap2.dtype == absmax2.dtype == torch.float32, + lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + ) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 80fc86861..ece18caa3 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -8,7 +8,6 @@ from typing_extensions import deprecated import bitsandbytes.functional as F -from bitsandbytes.functional import ipex_cpu, ipex_xpu # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py @@ -85,11 +84,7 @@ def get_inverse_transform_indices( return permuted_tile_indices -# torch.compiler.is_compiling() is available only in torch >= 2.3 -if hasattr(torch.compiler, "is_compiling"): - _is_compiling = torch.compiler.is_compiling -else: - _is_compiling = torch._dynamo.is_compiling +_is_compiling = torch.compiler.is_compiling @deprecated( @@ -320,8 +315,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) output = torch.nn.functional.linear(A, CB, bias) - # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu] - state.idx = False ctx.state = state ctx.dtype_A = A.dtype ctx.grad_shape = A.shape @@ -426,7 +419,7 @@ def matmul( state.threshold = threshold # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU if state.is_training: - if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu): + if A.device.type in ("cpu", "xpu"): return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state) @@ -440,17 +433,6 @@ def matmul_4bit( ): assert quant_state is not None - if A.device.type in ("cpu", "xpu") and A.requires_grad == False: - if getattr(quant_state, "ipex", False): - # IPEX CPU will change weight to 4D so don't need transpose - B = B.t() if B.dim() == 2 else B - out = F.gemv_4bit(A, B, out, state=quant_state) - if bias is not None: - out += bias - return out - else: - return MatMul4Bit.apply(A, B, out, bias, quant_state) - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 5f009ea40..e295cc2a3 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,13 +1,14 @@ -from collections.abc import Sequence import ctypes as ct +import logging import torch from bitsandbytes.functional import get_ptr from ..._ops import register_kernel -from ...cextension import lib -from ..utils import ipex_cpu +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib + +logger = logging.getLogger(__name__) # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+. # However, we can overflow if we use this without AVX512_VNNI support. @@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -@register_kernel("bitsandbytes::quantize_blockwise", "cpu") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - n = A.numel() - - # Only FP32 has c++ kernrl - if A.dtype == torch.float32: - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(n), - ) - else: - rem = n % blocksize - has_rem = rem > 0 - blocks = n // blocksize + has_rem - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) - out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cpu") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - - # Only FP32 has c++ kernrl - if dtype == torch.float32: - out = torch.empty_like(A, dtype=dtype) - - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) - else: - out = code[A.reshape(-1).int()] - blocks = out.shape[-1] // blocksize - res = out.shape[-1] % blocksize - if res != 0: - out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) - out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) - out = out[: blocks * blocksize + res] - out = out.reshape(A.shape) - - return out - - -if ipex_cpu: - from bitsandbytes.utils import _reverse_4bit_compress_format - - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu") +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + + @register_kernel("bitsandbytes::quantize_blockwise", "cpu") + def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + n = A.numel() + + # Only FP32 has c++ kernrl + if A.dtype == torch.float32: + blocks = -(n // -blocksize) + + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(n), + ) + else: + rem = n % blocksize + has_rem = rem > 0 + blocks = n // blocksize + has_rem + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) + + return out, absmax + + @register_kernel("bitsandbytes::dequantize_blockwise", "cpu") def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype ) -> torch.Tensor: - ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) - A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) - return torch.ops.bitsandbytes.dequantize_4bit.default( - A, - absmax, - blocksize, - "nf4", - shape, - dtype, - ) + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + + # Only FP32 has c++ kernrl + if dtype == torch.float32: + out = torch.empty_like(A, dtype=dtype) + + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + else: + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 13359bbd8..30cad3e34 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -538,3 +538,229 @@ def _gemv_4bit_impl( ct.c_int32(blocksize), stream, ) + + +"""C FUNCTIONS FOR OPTIMIZERS""" +str2optimizer32bit = { + "adam": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "momentum": ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), + "lamb": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "ademamix": ( + lib.cademamix32bit_grad_fp32, + lib.cademamix32bit_grad_fp16, + lib.cademamix32bit_grad_bf16, + ), +} + +str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + lib.cmomentum_8bit_blockwise_grad_bf16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + lib.crmsprop_8bit_blockwise_grad_bf16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + lib.cadagrad_8bit_blockwise_grad_bf16, + ), + "ademamix": ( + lib.cademamix_8bit_blockwise_grad_fp32, + lib.cademamix_8bit_blockwise_grad_fp16, + lib.cademamix_8bit_blockwise_grad_bf16, + ), +} + + +def _optimizer_update_32bit_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + skip_zeros=False, +) -> None: + optim_fns = str2optimizer32bit.get(optimizer_name, None) + if optim_fns is None: + raise ValueError( + f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}" + ) + if g.dtype == torch.float32: + optim_func = optim_fns[0] + elif g.dtype == torch.float16: + optim_func = optim_fns[1] + elif g.dtype == torch.bfloat16 and len(optim_fns) == 3: + optim_func = optim_fns[2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + + with _cuda_device_of(g): + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + + +def _optimizer_update_8bit_blockwise_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros=False, +) -> None: + # torch._check( + # g.numel() == p.numel(), + # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + # ) + # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + # torch._check( + # g.dtype in compute_dtypes, + # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + # ) + # torch._check( + # g.dtype == p.dtype, + # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + # ) + # torch._check( + # state1.dtype == torch.uint8, + # lambda: f"state1 must be uint8, got {state1.dtype}", + # ) + # torch._check( + # qmap1.dtype == absmax1.dtype == torch.float32, + # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + # ) + # if state2 is not None: + # torch._check( + # state2.dtype == torch.uint8, + # lambda: f"state2 must be uint8, got {state2.dtype}", + # ) + # torch._check( + # qmap2.dtype == absmax2.dtype == torch.float32, + # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + # ) + optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name) + if optimizer_fns is None: + raise ValueError( + f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}" + ) + + if g.dtype == torch.float32: + optimizer_fn = optimizer_fns[0] + elif g.dtype == torch.float16: + optimizer_fn = optimizer_fns[1] + elif g.dtype == torch.bfloat16: + optimizer_fn = optimizer_fns[2] + else: + raise ValueError( + f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16" + ) + + with _cuda_device_of(g): + optimizer_fn( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + + +register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl) +register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl) diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index ce5926979..067347d47 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from math import prod +from math import prod, sqrt from typing import Optional import torch @@ -301,3 +301,278 @@ def _( B_dq, bias=None, ) + + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + + +@torch.compile +def _optimizer_precondition_32bit( + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + optimizer_id: int, +): + """Preprocessing optimizer, computing update norm""" + + g_vals = gnorm_scale * g + + if optimizer_id == 3: # ADAM + correction1 = 1.0 / (1.0 - beta1**step) + correction2 = 1.0 / (1.0 - beta2**step) + + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + s1_vals = s1_vals * correction1 + s2_vals = s2_vals * correction2 + + update_vals = s1_vals / (torch.sqrt(s2_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 5: # ADEMAMIX + update_norm = state1 + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + update_norm = s1_vals * s1_vals + + elif optimizer_id == 4: # LION + s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals + update_norm = s1_vals + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + total_norm = torch.sum(update_norm) + unorm_vec.add_(total_norm) + + +@torch.compile +def _optimizer_update_32bit( + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + optimizer_id: int, +): + """Unified optimizer update kernel""" + + p_vals = p.float() + g_vals = (gnorm_scale * g).float() + if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0: + g_vals = g_vals + p_vals * weight_decay + + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = torch.sqrt(unorm_vec) + if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers + if current_unorm > max_unorm * param_norm + eps: + update_scale = (max_unorm * param_norm + eps) / current_unorm + else: # 2-state optimizers + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + if optimizer_id == 3: # ADAM + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = sqrt(1.0 - beta2**step) + step_size = -lr * correction2 / correction1 + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2)) + p_vals = p_vals + update_val + + state1.copy_(s1_vals) + state2.copy_(s2_vals) + + elif optimizer_id == 5: # ADEMAMIX + s1_vals = state1[0] + s3_vals = state1[1] + s2_vals = state2 + + m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals + m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals + nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = sqrt(1.0 - beta2**step) + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + mixed_momentum = (m1 / correction1) + (alpha * m2) + adaptive_term = (torch.sqrt(nu) / correction2) + eps + p_vals = p_vals - lr * (mixed_momentum / adaptive_term) + + state1[0].copy_(m1) + state1[1].copy_(m2) + state2.copy_(nu) + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + + update_val = update_scale * (-lr * s1_vals) + p_vals = p_vals + update_val + + state1.copy_(s1_vals) + + elif optimizer_id == 4: # LION + momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals + update_val = update_scale * lr * torch.sign(momentum_update) + p_vals = p_vals - update_val + + s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals + state1.copy_(s1_vals) + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) + + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) + + p.copy_(p_vals) + + +@register_kernel("bitsandbytes::optimizer_update_32bit", "default") +def _( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + 32-bit optimizer implemented by PyTorch with @torch.compile + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported yet") + + optimizer_id = name2optimizer_id[optimizer_name] + + if optimizer_name == "lion": + _optimizer_update_32bit( + g, + p, + state1, + state2, + unorm_vec, + max_unorm, + param_norm, + beta1, + beta2, + beta3, + alpha, + eps, + weight_decay, + step, + lr, + gnorm_scale, + optimizer_id, + ) + + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit( + g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id + ) + else: + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit( + g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id + ) + + _optimizer_update_32bit( + g, + p, + state1, + state2, + unorm_vec, + max_unorm, + param_norm, + beta1, + beta2, + beta3, + alpha, + eps, + weight_decay, + step, + lr, + gnorm_scale, + optimizer_id, + ) diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py index 4c43a3cb7..9ecd63e0b 100644 --- a/bitsandbytes/backends/hpu/ops.py +++ b/bitsandbytes/backends/hpu/ops.py @@ -3,12 +3,19 @@ import torch -from bitsandbytes.utils import _reverse_4bit_compress_format - from ..._ops import register_kernel from ..utils import GAUDI_SW_VER +# convert btw standard 4-bit compression format and ipex compression format +# needed for backward compatibility with older versions of gaudi sw +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + + @register_kernel("bitsandbytes::dequantize_4bit", "hpu") def _( A: torch.Tensor, diff --git a/bitsandbytes/backends/triton/triton_kernels.py b/bitsandbytes/backends/triton/kernels_4bit.py similarity index 78% rename from bitsandbytes/backends/triton/triton_kernels.py rename to bitsandbytes/backends/triton/kernels_4bit.py index 03ffa187d..0e94f49e8 100644 --- a/bitsandbytes/backends/triton/triton_kernels.py +++ b/bitsandbytes/backends/triton/kernels_4bit.py @@ -4,167 +4,6 @@ import triton.language as tl -# @triton.autotune( -# configs=[ -# # triton.Config({'SPLIT_SIZE': 64}), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128}), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_SIZE": 256}), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# triton.Config({"SPLIT_SIZE": 512}), -# # triton.Config({'SPLIT_SIZE': 1024}), -# ], -# key=["num_paired_elements", "QUANT_BLOCK"], -# ) -@triton.jit -def dequant_8bit_kernel( - a_ptr, - c_ptr, - quant_ptr, - absmax_ptr, - num_paired_elements, - QUANT_BLOCK: tl.constexpr, - SPLIT_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * SPLIT_SIZE - offsets = block_start + tl.arange(0, SPLIT_SIZE) - mask = offsets < num_paired_elements - - a = tl.load(a_ptr + offsets, mask) - a = a.to(tl.uint8) - - # apply conversion - scaled_int8 = tl.load(quant_ptr + a, mask) - - abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK - abs_offsets = offsets // QUANT_BLOCK - mask_blocked = offsets < abs_blocks_lim - - absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) - # apply scales - out_dq = scaled_int8 * absmax - - offs = block_start + tl.arange(0, SPLIT_SIZE) - mask = offs < num_paired_elements - tl.store(c_ptr + offs, out_dq, mask) - - -def dequant_int8_blockwise( - A_nf4: torch.Tensor, - quant_state_code: torch.Tensor, - absmax: torch.Tensor, - out: torch.Tensor, - quant_blocksize: int = 64, -): - number_of_paired_elements = A_nf4.numel() - - SPLIT_SIZE = 256 - # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) - grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) - dequant_8bit_kernel[grid]( - A_nf4, - out, - quant_state_code, - absmax, - number_of_paired_elements, - quant_blocksize, - SPLIT_SIZE, - ) - return out - - -# @triton.autotune( -# configs=[ -# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 1}), -# triton.Config({"SPLIT_NUM_BLOCKS": 2}), -# ], -# key=["n_elements"], -# ) -@triton.jit -def quantize_blockwise_kernel( - A_ptr, - code_ptr, - absmax_ptr, - out_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - CODE_SIZE: tl.constexpr, - SPLIT_NUM_BLOCKS: tl.constexpr, -): - block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS - thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) - - offsets = block_start_idx * BLOCK_SIZE + thread_idx - mask = offsets < n_elements - - A = tl.load(A_ptr + offsets, mask=mask, other=0.0) - - # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) - A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) - - # Calculating absamax for each block - absmax = tl.max(tl.abs(A_reshaped), axis=1) - tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) - - A_normalized = A_reshaped / absmax[:, None] - A_normalized = tl.clamp(A_normalized, -1.0, 1.0) - - lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) - upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) - - for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter - pivot = (lower_pivot + upper_pivot) // 2 - val = tl.load(code_ptr + pivot) - is_higher = A_normalized > val # code[pivot] - lower_pivot = tl.where(is_higher, pivot, lower_pivot) - upper_pivot = tl.where(is_higher, upper_pivot, pivot) - - # Choose closest level - lower_val = tl.load(code_ptr + lower_pivot) - upper_val = tl.load(code_ptr + upper_pivot) - lower_dist = tl.abs(A_normalized - lower_val) - upper_dist = tl.abs(A_normalized - upper_val) - quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) - - # too slow approach - # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) - # quantized = tl.argmin(diff, axis=2).to(tl.uint8) - - quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) - tl.store(out_ptr + offsets, quantized_flat, mask=mask) - - -def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out): - n = A.numel() - - split_num_blocks = 1 - grid = (triton.cdiv(blocks, split_num_blocks),) - # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) - quantize_blockwise_kernel[grid]( - A_ptr=A, - code_ptr=code, - absmax_ptr=absmax, - out_ptr=quantized_out, - n_elements=n, - BLOCK_SIZE=blocksize, - CODE_SIZE=code.numel(), - SPLIT_NUM_BLOCKS=split_num_blocks, - ) - - return quantized_out, absmax - - # Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4 # @triton.autotune( # configs=[ @@ -587,7 +426,7 @@ def dequant_nf4_kernel( tl.store(c_ptr + offs, out_dq, mask) -def _dequantize_4bit_impl( +def dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, @@ -611,7 +450,7 @@ def _dequantize_4bit_impl( dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) -def _dequantize_4bit_impl_passing_code( +def dequantize_4bit_impl_passing_code( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, diff --git a/bitsandbytes/backends/triton/kernels_8bit_quant.py b/bitsandbytes/backends/triton/kernels_8bit_quant.py new file mode 100644 index 000000000..c0a5a21ef --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_8bit_quant.py @@ -0,0 +1,195 @@ +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# # triton.Config({'SPLIT_SIZE': 64}), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128}), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_SIZE": 256}), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# triton.Config({"SPLIT_SIZE": 512}), +# # triton.Config({'SPLIT_SIZE': 1024}), +# ], +# key=["num_paired_elements", "QUANT_BLOCK"], +# ) +@triton.jit +def dequant_8bit_kernel( + a_ptr, + out_ptr, + code_ptr, + absmax_ptr, + n, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < n + out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK) + tl.store(out_ptr + offsets, out_dq, mask) + + +def dequant_8bit_blockwise( + a: torch.Tensor, + absmax: torch.Tensor, + quant_state_code: torch.Tensor, + quant_blocksize: int = 64, + dtype: torch.dtype = None, + out: torch.Tensor = None, +): + n = a.numel() + if out is None: + if dtype is None: + raise ValueError("If out is None, dtype must be specified") + out = torch.empty_like(a, dtype=dtype, device=a.device) + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) + grid = (triton.cdiv(n, SPLIT_SIZE),) + dequant_8bit_kernel[grid]( + a, + out, + quant_state_code, + absmax, + n, + quant_blocksize, + SPLIT_SIZE, + ) + return out + + +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_8bit_blockwise_kernel( + A_ptr, + code_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + CODE_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) + tl.store(out_ptr + offsets, quantized, mask=mask) + + +def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): + n = A.numel() + blocks = -(n // -blocksize) + + if absmax is None: + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) + if out is None: + out = torch.empty_like(A.flatten(), dtype=torch.uint8) + + split_num_blocks = 1 + grid = (triton.cdiv(blocks, split_num_blocks),) + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) + quantize_8bit_blockwise_kernel[grid]( + A_ptr=A, + code_ptr=code, + absmax_ptr=absmax, + out_ptr=out, + n_elements=n, + BLOCK_SIZE=blocksize, + CODE_SIZE=code.numel(), + SPLIT_NUM_BLOCKS=split_num_blocks, + # num_warps=1, + # num_stages=2, + ) + out = out.reshape(A.shape) + + return out, absmax + + +@triton.jit +def quantize_8bit_blockwise_kernel_util( + a, + code_ptr, + CODE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) + + # Calculating absmax for each block + absmax = tl.max(tl.abs(a_reshaped), axis=1) + + a_normalized = a_reshaped / absmax[:, None] + a_normalized = tl.clamp(a_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + for _ in range(8): + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(code_ptr + pivot) + is_higher = a_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) + lower_dist = tl.abs(a_normalized - lower_val) + upper_dist = tl.abs(a_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + # too slow approach + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) + return quantized_flat, absmax + + +@triton.jit +def dequant_8bit_blockwise_kernel_util( + a_ptr, + offsets, + code_ptr, + absmax_ptr, + mask, + BLOCK_SIZE: tl.constexpr, +): + a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8) + scaled_int8 = tl.load(code_ptr + a, mask) + # Load scales + absmax_offsets = offsets // BLOCK_SIZE + absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last") + # Apply scales + out_dq = scaled_int8 * absmax + return out_dq diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py new file mode 100644 index 000000000..2cd6d8c93 --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -0,0 +1,1154 @@ +import math +from typing import Optional + +import torch + +import triton +import triton.language as tl + +# from triton.language.extra import libdevice +from .kernels_8bit_quant import ( + dequant_8bit_blockwise, + dequant_8bit_blockwise_kernel_util, + quantize_8bit_blockwise_kernel_util, + quantize_blockwise_triton, +) + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + + +@triton.jit +def _optimizer_precondition_2state_32bit( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """Preprocessing optimizer, computing update norm (2-state optimizer)""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) + + g_vals = gnorm_scale * g_vals + + correction1 = 1.0 / (1.0 - beta1_step) + correction2 = 1.0 / (1.0 - beta2_step) + + if OPTIMIZER_ID == 3: # ADAM + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + s1_vals = s1_vals * correction1 + s2_vals = s2_vals * correction2 + + update_vals = s1_vals / (tl.sqrt(s2_vals) + eps) + + update_norm = update_vals * update_vals + + elif OPTIMIZER_ID == 5: # ADEMAMIX + update_norm = s1_vals + + total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) + + tl.atomic_add(unorm_ptr, total_norm) + + +@triton.jit +def _optimizer_precondition_1state_32bit( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + eps: tl.constexpr, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """Preprocessing optimizer, computing update norm (1-state optimizer)""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + + g_vals = gnorm_scale * g_vals + + if OPTIMIZER_ID == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = s1_vals * beta1 + g_vals + update_norm = s1_vals * s1_vals + + elif OPTIMIZER_ID == 4: # LION + s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals + update_norm = s1_vals + + elif OPTIMIZER_ID == 1: # RMSPROP + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals + update_vals = g_vals / (tl.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + elif OPTIMIZER_ID == 2: # ADAGRAD + s1_vals = s1_vals + g_vals * g_vals + update_vals = g_vals / (tl.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) + + tl.atomic_add(unorm_ptr, total_norm) + + +@triton.jit +def _optimizer_update_2state_32bit_triton_kernel( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + max_unorm: tl.constexpr, + param_norm, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + skip_zeros, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """2-state optimizer kernel""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) + + if OPTIMIZER_ID == 5: # ADEMAMIX + s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0) + + g_vals = gnorm_scale * g_vals + + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = tl.sqrt(tl.load(unorm_ptr)) + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + if OPTIMIZER_ID == 3: # ADAM + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1_step + correction2 = tl.sqrt(1.0 - beta2_step) + step_size = -lr * correction2 / correction1 + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2)) + p_vals = p_vals + update_val + + elif OPTIMIZER_ID == 5: # ADEMAMIX + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals # m1 + s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals # m2 + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals # nu + + correction1 = 1.0 - beta1_step + correction2 = tl.sqrt(1.0 - beta2_step) + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals) + adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps + p_vals = p_vals - lr * (mixed_momentum / adaptive_term) + + tl.store(p_ptr + offsets, p_vals, mask=mask) + tl.store(state1_ptr + offsets, s1_vals, mask=mask) + tl.store(state2_ptr + offsets, s2_vals, mask=mask) + + if OPTIMIZER_ID == 5: # ADEMAMIX + tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask) + + +@triton.jit +def _optimizer_update_1state_32bit_triton_kernel( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + max_unorm: tl.constexpr, + param_norm, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + skip_zeros, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """1-state optimizer kernel""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + + g_vals = gnorm_scale * g_vals + if weight_decay > 0.0: + g_vals = g_vals + p_vals * weight_decay + + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = tl.sqrt(tl.load(unorm_ptr)) + if current_unorm > max_unorm * param_norm + eps: + update_scale = (max_unorm * param_norm + eps) / current_unorm + + if OPTIMIZER_ID == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = s1_vals * beta1 + g_vals + + update_val = update_scale * (-lr * s1_vals) + p_vals = p_vals + update_val + + elif OPTIMIZER_ID == 4: # LION + momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals + update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0)) + p_vals = p_vals - update_val + + s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals + + elif OPTIMIZER_ID == 1: # RMSPROP + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals + + update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + elif OPTIMIZER_ID == 2: # ADAGRAD + s1_vals = s1_vals + g_vals * g_vals + + update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + tl.store(p_ptr + offsets, p_vals, mask=mask) + tl.store(state1_ptr + offsets, s1_vals, mask=mask) + + +name2optimizer_32bit_fn = { + "adam": { + "preprocess": _optimizer_precondition_2state_32bit, + "update": _optimizer_update_2state_32bit_triton_kernel, + }, + "ademamix": { + "preprocess": _optimizer_precondition_2state_32bit, + "update": _optimizer_update_2state_32bit_triton_kernel, + }, + "momentum": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "rmsprop": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "adagrad": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "lion": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, +} + + +def optimizer_update_32bit_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + 32-bit optimizer implemented by Triton + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + BLOCK_SIZE = 256 + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + optimizer_id = name2optimizer_id[optimizer_name] + fn_preprocess = name2optimizer_32bit_fn[optimizer_name]["preprocess"] + fn_update = name2optimizer_32bit_fn[optimizer_name]["update"] + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + beta1_step = beta1**step + beta2_step = beta2**step + + if optimizer_name == "lion": + fn_update[grid]( + g, + p, + state1, + state2, + unorm_vec, + max_unorm, + param_norm, + beta1, + beta2, + beta3, + alpha, + eps, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale, + skip_zeros, + p.numel(), + optimizer_id, + BLOCK_SIZE, + N_PER_TH, + num_warps=2, + ) + + if max_unorm > 0.0: + unorm_vec.zero_() + fn_preprocess[grid]( + g, + p, + state1, + state2, + unorm_vec, + beta1, + beta2, + eps, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale, + p.numel(), + optimizer_id, + BLOCK_SIZE, + N_PER_TH, + num_warps=2, + ) + + else: + if max_unorm > 0.0: + unorm_vec.zero_() + fn_preprocess[grid]( + g, + p, + state1, + state2, + unorm_vec, + beta1, + beta2, + eps, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale, + p.numel(), + optimizer_id, + BLOCK_SIZE, + N_PER_TH, + num_warps=2, + ) + + fn_update[grid]( + g, + p, + state1, + state2, + unorm_vec, + max_unorm, + param_norm, + beta1, + beta2, + beta3, + alpha, + eps, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale, + skip_zeros, + p.numel(), + optimizer_id, + BLOCK_SIZE, + N_PER_TH, + num_warps=2, + ) + + +########################################### +# Pure torch implementation for reference # +########################################### + + +@torch.compile +def _dequantize_blockwise_pytorch( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Pure PyTorch reference implementation for block-wise dequantization. + """ + if A.numel() == 0: + return torch.empty_like(A, dtype=dtype) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype) + + num_blocks = math.ceil(num_elements / blocksize) + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len)) + + dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize) + + rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype) + + rescaled_flat = rescaled_blocks.flatten() + if pad_len > 0: + rescaled_flat = rescaled_flat[:-pad_len] + + return rescaled_flat.reshape(A.shape) + + +@torch.compile +def _quantize_blockwise_pytorch( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pure PyTorch reference implementation for block-wise quantization. + """ + if A.numel() == 0: + return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + num_blocks = math.ceil(num_elements / blocksize) + + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + A_flat = torch.nn.functional.pad(A_flat, (0, pad_len)) + + A_blocks = A_flat.reshape(num_blocks, blocksize) + + absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0] + absmax[absmax == 0] = 1.0 + + scaled_blocks = A_blocks / absmax + + # Inefficient but straightforward quantization, takes a lot of memory + diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device)) + quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8) + + quantized_flat = quantized_indices.flatten() + if pad_len > 0: + quantized_flat = quantized_flat[:-pad_len] + + return quantized_flat.reshape(A.shape), absmax.flatten() + + +# Main updated function +def optimizer_update_8bit_blockwise_pytorch( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros: + raise ValueError("skip_zeros is not supported on XPU yet.") + + blocksize = 256 + + with torch.no_grad(): + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. + s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32) + s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32) + + grad = g.float() * gnorm_scale + + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + if optimizer_name == "adam": + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." + ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +####################################### +# Mixed torch + triton implementation # +####################################### + + +# Much more memory efficient due to using triton for quantization/dequantization +def optimizer_update_8bit_blockwise_triton_quant( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros and not torch.any(g): + return + + blocksize = 256 + grad = g.float() * gnorm_scale + + with torch.no_grad(): + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. + s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32) + s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32) + + # Apply optimizer-specific update logic + if optimizer_name == "adam": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." + ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +######################### +# Triton implementation # +######################### + + +@triton.jit +def _optimizer_update_1state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay, + gnorm_scale, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use one momentum state. + Supports: Momentum, RMSprop, Adagrad, Lion. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + + # 3. Optimizer-specific updates + # LION + if weight_decay > 0.0 and OPTIMIZER_ID == 2: + p *= 1.0 - lr * weight_decay + # Apply weight decay for momentum, rmsprop, adagrad + elif weight_decay > 0.0: + g += p * weight_decay + + # Momentum update + if OPTIMIZER_ID == 0: # MOMENTUM + if step == 1: + s1 = g + else: + s1 = s1 * beta1 + g + p -= lr * s1 + + # RMSprop update + elif OPTIMIZER_ID == 1: # RMSPROP + s1 = s1 * beta1 + (1.0 - beta1) * g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Adagrad update + elif OPTIMIZER_ID == 2: # ADAGRAD + s1 += g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Lion update + elif OPTIMIZER_ID == 4: # LION + val = s1 * beta1 + (1.0 - beta1) * g + update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0)) + p -= lr * update + s1 = s1 * beta2 + (1.0 - beta2) * g + + # 4. Store updated parameter and requantized state + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + +@triton.jit +def _optimizer_update_2state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + # ademamix changes alpha and beta3 + beta3, + # ademamix changes alpha and beta3 + alpha, + eps: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay: tl.constexpr, + gnorm_scale: tl.constexpr, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use two momentum states. + Supports: Adam, AdEMAMix. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + # 3. Optimizer-specific updates + if OPTIMIZER_ID == 3: # ADAM + s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + s1 = s1 * beta1 + (1.0 - beta1) * g + s2 = s2 * beta2 + (1.0 - beta2) * g * g + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + # bias_correction1 = 1.0 - libdevice.pow(beta1, step) + # bias_correction2 = 1.0 - libdevice.pow(beta2, step) + bias_correction1 = 1.0 - beta1_step + bias_correction2 = 1.0 - beta2_step + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps + p -= (lr / bias_correction1) * (s1 / denom) + + # Store updated parameter + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store states + s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, s2_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2) + + elif OPTIMIZER_ID == 5: # ADEMAMIX + # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu) + m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + m2 = dequant_8bit_blockwise_kernel_util( + state1_ptr + n_elements, + offsets, + qmap1_ptr, + absmax1_ptr + n_elements // BLOCK_SIZE_N, + mask, + BLOCK_SIZE_N, + ) + nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + m1 = m1 * beta1 + (1.0 - beta1) * g + m2 = m2 * beta3 + (1.0 - beta3) * g + nu = nu * beta2 + (1.0 - beta2) * g * g + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + # bias_correction1 = 1.0 - libdevice.pow(beta1, step) + # bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step)) + bias_correction1 = 1.0 - beta1_step + bias_correction2 = tl.sqrt(1.0 - beta2_step) + + update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps) + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + p -= lr * update + + # Store updated parameter + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store all three states + m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, m1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1) + + m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask) + tl.store( + absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, + new_absmax_m2, + ) + + nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, nu_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu) + + +name2optimizer_fn = { + "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel, + "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel, +} + + +def optimizer_update_8bit_blockwise_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + if optimizer_name == "ademamix": + # Handle AdEMAMIX's stacked state tensors + if state1.dim() < 2 or state1.shape[0] != 2: + raise ValueError( + f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}" + ) + if absmax1.dim() < 2 or absmax1.shape[0] != 2: + raise ValueError( + f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}" + ) + + BLOCK_SIZE = 256 + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + fn = name2optimizer_fn[optimizer_name] + optimizer_id = name2optimizer_id[optimizer_name] + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + beta1_step = beta1**step + beta2_step = beta2**step + + fn[grid]( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + beta1_step, + beta2_step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + p.numel(), + BLOCK_SIZE_N=BLOCK_SIZE, + N_PER_TH=N_PER_TH, + OPTIMIZER_ID=optimizer_id, + num_warps=2, + ) + + +# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl) +# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant +# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant) +optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_impl diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 1e2802ab5..66bff3c94 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -1,30 +1,25 @@ from collections.abc import Sequence +from typing import Optional import torch -from . import triton_kernels +from . import kernels_4bit, kernels_8bit_quant, kernels_optim # currently codes unused, kept for reference # Should be the same for quant/dequant # from bitsandbytes.functional import get_4bit_type # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu") # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu") +device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" +torch_accelerator_module = getattr(torch, device_type, torch.cuda) def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) - out = torch.empty_like(A.flatten(), dtype=torch.uint8) - - triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out) - out = out.reshape(A.shape) - - return out, absmax.float() + with torch_accelerator_module.device(A.device): + out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize) + return out, absmax.float() def dequantize_blockwise( @@ -33,21 +28,24 @@ def dequantize_blockwise( torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") - - out = torch.empty_like(A, dtype=dtype, device=A.device) - triton_kernels.dequant_int8_blockwise( - A, - code, - absmax, - out, - blocksize, - ) - + with torch_accelerator_module.device(A.device): + out = kernels_8bit_quant.dequant_8bit_blockwise( + A, + absmax, + code, + blocksize, + dtype=dtype, + ) return out def dequantize_blockwise_inplace( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, ) -> None: torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") @@ -55,13 +53,15 @@ def dequantize_blockwise_inplace( torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - triton_kernels.dequant_int8_blockwise( - A, - code, - absmax, - out, - blocksize, - ) + with torch_accelerator_module.device(A.device): + kernels_8bit_quant.dequant_8bit_blockwise( + A, + absmax, + code, + blocksize, + dtype=dtype, + out=out, + ) def quantize_4bit( @@ -84,9 +84,10 @@ def quantize_4bit( absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype) out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) - triton_kernels.quantize_4bit_blockwise_triton( - A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out - ) + with torch_accelerator_module.device(A.device): + kernels_4bit.quantize_4bit_blockwise_triton( + A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out + ) packed = out if quant_storage != torch.uint8: @@ -118,8 +119,9 @@ def dequantize_4bit( A = A.squeeze().view(torch.uint8).unsqueeze(1) out = torch.empty(shape, dtype=dtype, device=A.device) + with torch_accelerator_module.device(A.device): + kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) return out @@ -134,7 +136,8 @@ def dequantize_4bit_inplace( ) -> None: torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + with torch_accelerator_module.device(A.device): + kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) def gemv_4bit( @@ -150,17 +153,145 @@ def gemv_4bit( B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device) - triton_kernels._dequantize_4bit_impl_passing_code( - B, - absmax, - blocksize, - code, - dtype=A.dtype, - out=B_dq_triton, - ) + with torch_accelerator_module.device(A.device): + kernels_4bit.dequantize_4bit_impl_passing_code( + B, + absmax, + blocksize, + code, + dtype=A.dtype, + out=B_dq_triton, + ) + + return torch.nn.functional.linear( + A, + B_dq_triton, + bias=None, + ) + + +# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms +# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms +# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms +optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + # torch._check( + # g.numel() == p.numel(), + # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + # ) + # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] - return torch.nn.functional.linear( - A, - B_dq_triton, - bias=None, - ) + # torch._check( + # g.dtype in compute_dtypes, + # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + # ) + # torch._check( + # g.dtype == p.dtype, + # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + # ) + # torch._check( + # state1.dtype == torch.uint8, + # lambda: f"state1 must be uint8, got {state1.dtype}", + # ) + # torch._check( + # qmap1.dtype == absmax1.dtype == torch.float32, + # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + # ) + # if state2 is not None: + # torch._check( + # state2.dtype == torch.uint8, + # lambda: f"state2 must be uint8, got {state2.dtype}", + # ) + # torch._check( + # qmap2.dtype == absmax2.dtype == torch.float32, + # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + # ) + + with torch_accelerator_module.device(state1.device): + optimizer_update_8bit_blockwise_impl( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + eps=eps, + step=step, + lr=lr, + qmap1=qmap1, + qmap2=qmap2, + absmax1=absmax1, + absmax2=absmax2, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, + ) + + +def optimizer_update_32bit( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + skip_zeros=False, +) -> None: + with torch_accelerator_module.device(state1.device): + kernels_optim.optimizer_update_32bit_impl( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + unorm_vec=unorm_vec, + max_unorm=max_unorm, + param_norm=param_norm, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + step=step, + lr=lr, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, + ) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py old mode 100755 new mode 100644 index 1543f3474..34e3d5faa --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -3,22 +3,12 @@ from packaging import version import torch -try: - # to support Intel CPU/XPU (IPEX) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = None - try: import triton # noqa: F401 import triton.language as tl # noqa: F401 triton_available = True -except ImportError as e: +except ImportError: triton_available = False diff --git a/bitsandbytes/backends/xpu/__init__.py b/bitsandbytes/backends/xpu/__init__.py old mode 100755 new mode 100644 diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py old mode 100755 new mode 100644 index 999116c97..a0620dc4b --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,14 +1,20 @@ from collections.abc import Sequence -import warnings +import ctypes as ct +import logging +from packaging import version import torch +from bitsandbytes.functional import _get_tensor_stream, get_ptr + from ..._ops import register_kernel -from ..utils import ipex_xpu, triton_available +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib +from ..utils import triton_available + +logger = logging.getLogger(__name__) -# _int_mm is available in torch starting from 2.7 version, -# but currently it's don't have xpu implementation. -if ipex_xpu and torch.__version__ >= (2, 7): +# _int_mm is available in torch starting from 2.9 version +if version.parse(torch.__version__).release >= version.parse("2.9").release: @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") def _(A: torch.Tensor, B: torch.Tensor): @@ -18,42 +24,209 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -# IPEX should be faster for xpu, so at first checking if it is available. -if ipex_xpu: +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + m = ct.c_int32(1) + n = ct.c_int32(shapeB[0]) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + if A.dtype == torch.float16: + lib.cgemv_4bit_inference_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemv_4bit_inference_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemv_4bit_inference_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + + +# SYCL should be faster for xpu, so at first checking if it is available. +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + logger.info("Register sycl bitsandbytes kernels for XPU") + + # TODO: Remove the triton register when quantization sycl kernel is ready. + if triton_available: + from ..triton import ops as triton_ops + + register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) + register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit) + register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")( + triton_ops.optimizer_update_8bit_blockwise + ) + register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit) - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") + @register_kernel("bitsandbytes::dequantize_4bit", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, + quant_type: str, shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out @register_kernel("bitsandbytes::dequantize_blockwise", "xpu") + def _( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype + ) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + @register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, + out: torch.Tensor, + ) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + @register_kernel("bitsandbytes::gemv_4bit", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, ) -> torch.Tensor: - shape = A.shape - out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) - # void cdequantize_blockwise_fp32( - # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) - if dtype == torch.float16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.bfloat16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.float32: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out - return out.reshape(shape) + @register_kernel("bitsandbytes::gemv_4bit.out", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, + ) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) elif triton_available: + logger.info("Register triton bitsandbytes kernels for XPU") from ..triton import ops as triton_ops register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) @@ -63,5 +236,7 @@ def _( register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) + register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise) + register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit) else: - warnings.warn("XPU available but no ipex or triton packages found.") + logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index bb301e712..2eb584a66 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -283,6 +283,9 @@ def get_native_library() -> BNBNativeLibrary: binary_path = cuda_binary_path + if torch._C._has_xpu: + binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}" + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") # Try to load the library - any errors will propagate up @@ -291,39 +294,32 @@ def get_native_library() -> BNBNativeLibrary: if hasattr(dll, "get_context"): # only a CUDA-built library exposes this return CudaBNBNativeLibrary(dll) - logger.warning( - "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers and GPU quantization are unavailable." - ) return BNBNativeLibrary(dll) ROCM_GPU_ARCH = get_rocm_gpu_arch() -try: - # to support Intel CPU/GPU (XPU) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = None +HIP_ENVIRONMENT = False +BNB_BACKEND = "CPU" +if torch.version.hip: + HIP_ENVIRONMENT = True + BNB_BACKEND = "ROCm" +elif torch.cuda.is_available(): + BNB_BACKEND = "CUDA" +elif torch._C._has_xpu: + BNB_BACKEND = "XPU" try: - if torch.version.hip: - HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - else: - HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - lib = get_native_library() except Exception as e: - error_msg = str(e) - if not (ipex_cpu or ipex_xpu): + if BNB_BACKEND in ("CPU", "XPU"): + lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.") + else: + error_msg = str(e) logger.error( - f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops", + f"bitsandbytes library load error: {error_msg}", exc_info=True, ) - # create a mock with error messaging as fallback - lib = ErrorHandlerMockBNBNativeLibrary(error_msg) + # create a mock with error messaging as fallback + lib = ErrorHandlerMockBNBNativeLibrary(error_msg) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9b446a2de..7cca33dcf 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,48 +13,13 @@ from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} """C FUNCTIONS FOR OPTIMIZERS""" -str2optimizer32bit = { - "adam": ( - lib.cadam32bit_grad_fp32, - lib.cadam32bit_grad_fp16, - lib.cadam32bit_grad_bf16, - ), - "momentum": ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, - ), - "rmsprop": ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, - ), - "lion": ( - lib.clion32bit_grad_fp32, - lib.clion32bit_grad_fp16, - lib.clion32bit_grad_bf16, - ), - "adagrad": ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ), - "lamb": ( - lib.cadam32bit_grad_fp32, - lib.cadam32bit_grad_fp16, - lib.cadam32bit_grad_bf16, - ), - "ademamix": ( - lib.cademamix32bit_grad_fp32, - lib.cademamix32bit_grad_fp16, - lib.cademamix32bit_grad_bf16, - ), -} - str2optimizer8bit = { "adam": ( lib.cadam_static_8bit_grad_32, @@ -82,39 +47,6 @@ ), } -str2optimizer8bit_blockwise = { - "adam": ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, - ), - "momentum": ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - lib.cmomentum_8bit_blockwise_grad_bf16, - ), - "rmsprop": ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - lib.crmsprop_8bit_blockwise_grad_bf16, - ), - "lion": ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, - ), - "adagrad": ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - lib.cadagrad_8bit_blockwise_grad_bf16, - ), - "ademamix": ( - lib.cademamix_8bit_blockwise_grad_fp32, - lib.cademamix_8bit_blockwise_grad_fp16, - lib.cademamix_8bit_blockwise_grad_bf16, - ), -} - class GlobalPageManager: _instance = None @@ -310,7 +242,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] - pvalues = [] for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): evalues.append(2**val) @@ -422,8 +353,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): for t in tensors: # NULL pointers and paged tensors are OK. if t is not None and not getattr(t, "is_paged", False): - on_gpu &= t.is_cuda - gpu_ids.add(t.device.index) + on_gpu &= t.device.type != "cpu" + gpu_ids.add((t.device.type, t.device.index)) if not on_gpu: raise RuntimeError( @@ -439,6 +370,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. + if tensor.device.type == "xpu": + return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) @@ -1053,16 +986,6 @@ def dequantize_4bit( if absmax.dtype != torch.float32: absmax = absmax.float() - # IPEX format is different, we need extra process. - if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4": - return torch.ops.bitsandbytes.dequantize_nf4_ipex( - A, - absmax, - quant_state.blocksize, - quant_state.shape, - quant_state.dtype, - ) - if out is not None: torch.ops.bitsandbytes.dequantize_4bit.out( A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out @@ -1252,41 +1175,27 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None - if g.dtype == torch.float32: - optim_func = str2optimizer32bit[optimizer_name][0] - elif g.dtype == torch.float16: - optim_func = str2optimizer32bit[optimizer_name][1] - elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: - optim_func = str2optimizer32bit[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - is_on_gpu([g, p, state1, state2, unorm_vec]) - - with _cuda_device_of(g): - optim_func( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + torch.ops.bitsandbytes.optimizer_update_32bit( + optimizer_name, + g, + p, + state1, + state2, + unorm_vec, + max_unorm, + param_norm, + beta1, + beta2, + beta3, + alpha, + eps, + weight_decay, + step, + lr, + gnorm_scale, + skip_zeros, + ) @deprecated( @@ -1447,47 +1356,29 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None - - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - with _cuda_device_of(g): - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + torch.ops.bitsandbytes.optimizer_update_8bit_blockwise( + optimizer_name, + g, + p, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) @@ -1631,25 +1522,6 @@ def gemv_4bit( if state.nested: absmax = dequantize_blockwise(absmax, state.state2) + state.offset - if getattr(state, "ipex", False) and state.quant_type == "nf4": - # compute_dtype: 1 indicates fp16, 2 indicates bf16 - compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 - out = torch.ops.torch_ipex.woq_linear( - A, - B, - "nf4", - state.shape, - state.new_scales, - state.new_zeros, - None, - None, - state.blocksize, - compute_dtype, - 1, - state.compensation, - ) - return out - if out is not None: torch.ops.bitsandbytes.gemv_4bit.out( A, @@ -2214,7 +2086,7 @@ def spmm_coo( assert cooA.values.numel() == nnz assert cooA.cols == B.shape[0] - transposed_B = False if B.is_contiguous() else True + transposed_B = not B.is_contiguous() ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -2263,12 +2135,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): assert cooA.values.numel() == nnz assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - transposed_B = False if B.is_contiguous() else True - - ldb = B.stride()[(1 if transposed_B else 0)] - ldc = B.shape[1] - - values, counts = torch.unique(cooA.rowidx, return_counts=True) + _, counts = torch.unique(cooA.rowidx, return_counts=True) offset = counts.cumsum(0).int() max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() @@ -2288,11 +2155,8 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): cnnz_rows = ct.c_int32(counts.numel()) cnnz = ct.c_int32(cooA.nnz) crowsA = ct.c_int32(cooA.rows) - ccolsA = ct.c_int32(cooA.cols) crowsB = ct.c_int32(B.shape[1]) ccolsB = ct.c_int32(B.shape[1]) - cldb = ct.c_int32(ldb) - cldc = ct.c_int32(ldc) with _cuda_device_of(B): is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) @@ -2336,49 +2200,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): C = 127.0 - - -def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): - quant_state = linear.weight.quant_state - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - quant_state.absmax = absmax - quant_state.nested = False - delattr(quant_state, "state2") - - if x.device.type == "cpu" and ipex_cpu: - converted_weight = _reverse_4bit_compress_format(linear.weight.data) - new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( - converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), - "nf4", - quant_state.shape, # weight shape - quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales - None, # zero_points - None, # bias - None, # batch_size - quant_state.blocksize, - 2, - ) - elif x.device.type == "xpu" and ipex_xpu: - new_weight = _reverse_4bit_compress_format(linear.weight.data) - new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) - new_zeros = None - compensation = None - new_scales = list(new_scales) - if not linear.training and not x.requires_grad: - new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) - else: - raise ValueError( - "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7" - ) - - linear.weight.data = new_weight.data - linear.weight.quant_state.ipex = True - linear.weight.quant_state.new_scales = new_scales - linear.weight.quant_state.new_zeros = new_zeros - linear.weight.quant_state.compensation = compensation diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ba134f52a..1adf75e79 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,13 +12,9 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu +from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, - OutlierTracer, - _reverse_4bit_compress_format, -) +from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer T = TypeVar("T", bound="torch.nn.Module") @@ -356,6 +352,46 @@ def to(self, *args, **kwargs): return new_param + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in [torch.chunk, torch.split]: + tensor = args[0] + + result = super().__torch_function__(func, types, args, kwargs) + + if isinstance(result, tuple): + return tuple( + cls( + data=chunk, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + for chunk in result + ) + else: + return cls( + data=result, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + + return super().__torch_function__(func, types, args, kwargs) + def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]): if getattr(module.weight, "quant_state", None) is not None: @@ -440,10 +476,9 @@ def __init__( ) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype - self.compute_type_is_set = False if compute_dtype is None else True + self.compute_type_is_set = compute_dtype is not None self.quant_state = None self.quant_storage = quant_storage - self.ipex_linear_is_set = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -470,40 +505,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ - if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False): - if self.weight.device.type == "cpu": - original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( - self.weight, "nf4", self.weight.quant_state.shape, 2 - ) - self.weight.data = _reverse_4bit_compress_format(original_weight.data) - elif self.weight.device.type == "xpu": - self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) - - self.weight.quant_state.ipex = False - self.ipex_linear_is_set = False - super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." + k] = v if keep_vars else v.detach() - def set_ipex_linear(self, x: torch.Tensor): - if ( - not getattr(self.weight.quant_state, "ipex", False) - and self.weight.data.dtype == torch.uint8 - and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 - and self.weight.quant_state.quant_type == "nf4" - ): - if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): - _enable_ipex_fusion(self, x) - def forward(self, x: torch.Tensor): - # Check if ipex fusion can be used - if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu): - self.set_ipex_linear(x) - self.ipex_linear_is_set = True - fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -519,8 +527,7 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - # IPEX CPU will change weight to 4D so don't need transpose - weight = self.weight.t() if self.weight.dim() == 2 else self.weight + weight = self.weight.t() return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) @@ -675,7 +682,7 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu): + elif self.data.dtype == torch.int8 and device.type == "cpu": self.CB = self.data new_param = Int8Params( @@ -1110,4 +1117,4 @@ def forward(self, x): if self.weight.CB is not None: self.init_8bit_state() - out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias + return bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias diff --git a/bitsandbytes/nn/parametrize.py b/bitsandbytes/nn/parametrize.py new file mode 100644 index 000000000..4a956c7fa --- /dev/null +++ b/bitsandbytes/nn/parametrize.py @@ -0,0 +1,192 @@ +from functools import partial +from typing import Any, Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.utils.parametrize as P + +from .. import functional as F + + +class Bnb4bitParametrization(nn.Module): + """ + A parametrization module that handles dequantization of a 4-bit quantized parameter. + + The parameter data is expected to be already quantized when this parametrization is applied. + This module will dequantize the parameter data to its original floating-point representation + when the forward method is called (i.e. when the parameter is accessed). + + Args: + quant_state (`F.QuantState`): + The quantization state containing the necessary information for dequantization. + """ + + def __init__(self, quant_state: F.QuantState): + super().__init__() + self.quant_state = quant_state + + @torch.no_grad() + def forward(self, quantized_param: torch.Tensor) -> torch.Tensor: + """ + Forward pass to dequantize the parameter. + + Args: + quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original) + + Returns: + `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype. + """ + return F.dequantize_4bit(quantized_param, self.quant_state) + + +def replace_parameter_4bit_prequantized( + module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device +): + if not hasattr(module, param_name): + raise AttributeError(f"Module does not have parameter '{param_name}'") + + original_param = getattr(module, param_name) + + if not isinstance(original_param, nn.Parameter): + raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter") + + quant_state = F.QuantState.from_dict(qs_dict, device=device) + + # Apply a parametrization to the module to handle dequantization. + P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True) + + # Next, register hooks. + _register_parametrization_hooks(module, param_name) + + +def replace_parameter_4bit( + module: nn.Module, + param_name: str, + compress_statistics: bool = False, + quant_type: Literal["nf4", "fp4"] = "nf4", + blocksize: Optional[int] = None, +): + """ + Replace a module parameter with a 4-bit quantized version using parametrization. + + This function quantizes an existing parameter in a PyTorch module to 4-bit precision + and sets up parametrization to handle automatic dequantization during forward passes. + The original parameter is replaced with quantized data, and a parametrization layer + is registered to manage the quantization state and dequantization process. + + Additional, it registers a state dict post-hook to ensure that the quantization state + is saved correctly when the model's state dict is saved. + + It is useful for MoE models or other scenarios where you want to quantize parameters + outside of nn.Linear layers without changing the model's architecture. + + This feature is experimental and may change in future releases. + + Args: + module (`nn.Module`): + The PyTorch module containing the parameter to be quantized. + param_name (`str`): + The name of the parameter within the module to quantize. + compress_statistics (`bool`, *optional*, defaults to `False`): + Whether to compress quantization statistics to reduce memory usage. + quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`): + The quantization format to use. + blocksize (`int`, *optional*, defaults to `None`): + The block size for quantization. If None, uses the default block size. + + Raises: + AttributeError: If the module does not have the specified parameter. + TypeError: If the specified attribute is not an instance of nn.Parameter. + """ + + if not hasattr(module, param_name): + raise AttributeError(f"Module does not have parameter '{param_name}'") + + original_param = getattr(module, param_name) + + if not isinstance(original_param, nn.Parameter): + raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter") + + # Quantize the original parameter. + quantized_data, quant_state = F.quantize_4bit( + original_param.data, + blocksize=blocksize, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + + # Replace the parameter with the quantized data. + setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False)) + del original_param + + # Apply a parametrization to the module to handle dequantization. + P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True) + + # Next, register hooks. + _register_parametrization_hooks(module, param_name) + + +def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any): + P._cache_enabled -= 1 + if not P._cache_enabled: + P._cache = {} + + +def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]): + P._cache_enabled += 1 + + +def _register_parametrization_hooks(module: nn.Module, param_name: str): + # Register a state dict hook for saving. Note that this requires torch >= 2.5.0. + if torch.__version__ >= (2, 5): + module.register_state_dict_post_hook( + partial( + _parametrized_state_dict_post_hook, + param_name=param_name, + ) + ) + + # Register hooks to enable caching for the dequantization parametrization. + # This helps preserve time and memory when the same quantized parameter + # is accessed multiple times in the forward computation. + module.register_forward_pre_hook(_enable_parametrization_cache) + module.register_forward_hook(_disable_parametrization_cache) + + +def _parametrized_state_dict_post_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + local_metadata: Any, + *, + param_name: str = "weight", + **kwargs: dict[str, Any], +) -> None: + """ + Hook to modify the state dict to include the quantization state. + """ + + original_key = f"{prefix}parametrizations.{param_name}.original" + + if original_key in state_dict: + # Create a clean entry. + # The `parametrizations.{param_name}.original` key will have the quantized data, + # but we would like it to keep it in the state_dict as `{param_name}`. + clean_key = f"{prefix}{param_name}" + state_dict[clean_key] = state_dict.pop(original_key) + + assert P.is_parametrized(module, param_name) + + # Find the parametrization, which should have the quantization state. + parametrization: Bnb4bitParametrization = next( + filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None + ) + + assert parametrization is not None, "Parametrization not found for the parameter." + + quant_state = parametrization.quant_state + + # Next, we need to store the quantization state. + if quant_state is not None: + for k, v in quant_state.as_dict(packed=True).items(): + state_dict[f"{prefix}{param_name}.{k}"] = v diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index a32394bd5..5f225c9ad 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -26,7 +26,7 @@ def __init__( Base AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -87,7 +87,7 @@ def __init__( 8-bit AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -159,7 +159,7 @@ def __init__( 32-bit AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -219,7 +219,7 @@ def __init__( Paged AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -241,8 +241,6 @@ def __init__( Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. - is_paged (`bool`, defaults to `False`): - Whether the optimizer is a paged optimizer or not. """ super().__init__( "adam", @@ -279,7 +277,7 @@ def __init__( Paged 8-bit AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -303,8 +301,6 @@ def __init__( Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. - is_paged (`bool`, defaults to `False`): - Whether the optimizer is a paged optimizer or not. """ # Validate unsupported parameters if amsgrad: @@ -350,7 +346,7 @@ def __init__( Paged 32-bit AdamW optimizer. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -372,8 +368,6 @@ def __init__( Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. - is_paged (`bool`, defaults to `False`): - Whether the optimizer is a paged optimizer or not. """ super().__init__( "adam", diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 90c3686fe..fa2af57bc 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -231,9 +231,6 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - params_with_grad = [] - d_p_list = [] - momentum_buffer_list = [] weight_decay = group["weight_decay"] momentum = group["momentum"] dampening = group["dampening"] diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 9c20f9376..ea3ff32c9 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -10,6 +10,7 @@ import torch import bitsandbytes.functional as F +from bitsandbytes.utils import sync_gpu class MockArgs: @@ -64,9 +65,9 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None) parameters (`torch.Tensor` or `list(torch.Tensors)`): The input parameters. key (`str`): - The hyperparamter to override. + The hyperparameter to override. value: - The hyperparameter values. + The hyperparameter value. key_value_dict (`dict`): A dictionary with multiple key-values to override. @@ -115,7 +116,7 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): Base 8-bit optimizer class. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. @@ -271,14 +272,13 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - overflows = [] - if not self.initialized: self.check_overrides() self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True # if self.is_paged: self.page_mng.prefetch_all() + p = None for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -289,11 +289,11 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() - if self.is_paged: - # all paged operation are asynchronous, we need + sync_gpu(p) + if self.is_paged and p is not None: + # all paged operations are asynchronous, we need # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() + sync_gpu(p) return loss @@ -371,7 +371,7 @@ def __init__( Arguments: optimizer_name (`str`): The name of the optimizer. - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -428,7 +428,6 @@ def __init__( if args is None: args = {} args["optim_bits"] = optim_bits - args["percentile_clipping"] = 100 args["min_8bit_size"] = min_8bit_size args["percentile_clipping"] = percentile_clipping args["block_wise"] = block_wise @@ -613,7 +612,7 @@ def __init__( Arguments: optimizer_name (`str`): The name of the optimizer. - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -655,7 +654,6 @@ def __init__( if args is None: args = {} args["optim_bits"] = optim_bits - args["percentile_clipping"] = 100 args["min_8bit_size"] = min_8bit_size args["percentile_clipping"] = percentile_clipping args["block_wise"] = block_wise diff --git a/bitsandbytes/py.typed b/bitsandbytes/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index d9718382b..9c7afc354 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -235,7 +235,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non # 2. Quantize B if state.has_fp16_weights: # print('B shape', B.shape) - has_grad = True if (getattr(B, "grad", None) is not None) else False + has_grad = getattr(B, "grad", None) is not None is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: B = B.contiguous() diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index b706ff1ba..f6bedd8cd 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -4,11 +4,8 @@ @functools.lru_cache(None) def is_triton_available(): try: - # torch>=2.2.0 from torch.utils._triton import has_triton, has_triton_package return has_triton_package() and has_triton() - except ImportError: - from torch._inductor.utils import has_triton - - return has_triton() + except Exception: + return False diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 7920e2188..1af07710c 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -38,14 +38,6 @@ def outlier_hook(module, input): hook.remove() -# convert btw standard 4-bit compression format and ipex compression format -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - - class OutlierTracer: _instance = None @@ -92,11 +84,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) if rdm: return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() - m = weight.mean(reduction_dim) - mm = m.mean() - mstd = m.std() - zm = (m - mm) / mstd - std = weight.std(reduction_dim) stdm = std.mean() stdstd = std.std() @@ -209,3 +196,10 @@ def unpack_tensor_to_dict(tensor_data): LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3} INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()} + + +def sync_gpu(t: torch.Tensor): + if t.device.type == "cuda": + torch.cuda.synchronize() + elif t.device.type == "xpu": + torch.xpu.synchronize() diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 649f2ee1f..738ae0cd1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -21,23 +21,34 @@ #define NUM 4 #define NUM_BLOCK 4096 -__device__ static float nf4_data[16] = { - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0 +__device__ static float fp4_dequantization_lut[8] = { + 0.0f, // 0b000 + 0.005208333333f, // 0b001 + 0.66666667f, // 0b010 + 1.0f, // 0b011 + 0.33333333f, // 0b100 + 0.5f, // 0b101 + 0.16666667f, // 0b110 + 0.25f // 0b111 +}; + +__device__ static float nf4_dequantization_lut[16] = { + -1.0f, // 0b0000 + -0.6961928009986877f, // 0b0001 + -0.5250730514526367f, // 0b0010 + -0.39491748809814453f, // 0b0011 + -0.28444138169288635f, // 0b0100 + -0.18477343022823334f, // 0b0101 + -0.09105003625154495f, // 0b0110 + 0.0f, // 0b0111 + 0.07958029955625534f, // 0b1000 + 0.16093020141124725f, // 0b1001 + 0.24611230194568634f, // 0b1010 + 0.33791524171829224f, // 0b1011 + 0.44070982933044434f, // 0b1100 + 0.5626170039176941f, // 0b1101 + 0.7229568362236023f, // 0b1110 + 1.0f // 0b1111 }; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda @@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) { return __int_as_float(old); } -__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { - float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if ((val & 0b0100) == 4) // 0 - if ((val & 0b0010) == 2) // 01 - if ((val & 0b0001) == 1) // 111 - return 0.25000000f * absmax * sign; // 1111 - else - return 0.16666667f * absmax * sign; // 1110 - else if ((val & 0b0001) == 1) // 110 - return 0.50000000f * absmax * sign; // 1101 - else - return 0.33333333f * absmax * sign; // 1100 - else if ((val & 0b0010) == 2) // 10 - if ((val & 0b0001) == 1) // 101 - return 1.00000000f * absmax * sign; // 1011 - else - return 0.66666667f * absmax * sign; // 1010 - else if ((val & 0b0001) == 1) // 100 - return 5.208333333e-03f * absmax * sign; // 1001 - else - return 0.00000000f * absmax * sign; // 1000 +__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { + float sign = 1.0f - 2 * ((val & 0b1000) >> 3); + return fp4_dequantization_lut[val & 0b111] * sign; } __device__ unsigned char dQuantizeFP4(float x) { @@ -118,51 +111,7 @@ __device__ unsigned char dQuantizeFP4(float x) { return 0b0000 + sign; } -__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { - - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if ((val & 0b1000) == 8) - if ((val & 0b0100) == 4) // 1 - if ((val & 0b0010) == 2) // 11 - if ((val & 0b0001) == 1) // 111 - return 1.0f; - else - return 0.7229568362236023f; - else if ((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else if ((val & 0b0010) == 2) // 10 - if ((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; - else if ((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; - else - return 0.07958029955625534f; - - else if ((val & 0b0100) == 4) // 0 - if ((val & 0b0010) == 2) // 01 - if ((val & 0b0001) == 1) // 011 - return 0.0f; - else - return -0.09105003625154495f; - else if ((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; - else - return -0.28444138169288635f; - else if ((val & 0b0010) == 2) // 00 - if ((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else if ((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; -} +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } __device__ unsigned char dQuantizeNF4(float x) { @@ -431,7 +380,6 @@ __global__ void kQuantizeBlockwise( LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - unsigned char packed_4bit = 0; switch (DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH @@ -445,17 +393,15 @@ __global__ void kQuantizeBlockwise( case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } @@ -513,8 +459,8 @@ __global__ void case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: @@ -2355,7 +2301,7 @@ __global__ void kgemm_4bit_inference( #pragma unroll 16 for (int i = 0; i < 16; i++) - quant_map[i] = nf4_data[i]; + quant_map[i] = nf4_dequantization_lut[i]; //__shared__ T quant_map[16*160]; T local_A[2]; diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 58f6ed065..bef6cffa6 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -19,37 +19,42 @@ #define NUM 4 #define NUM_BLOCK 4096 -__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +__device__ static float fp4_dequantization_lut[8] = { + 0.0f, // 0b000 + 0.005208333333f, // 0b001 + 0.66666667f, // 0b010 + 1.0f, // 0b011 + 0.33333333f, // 0b100 + 0.5f, // 0b101 + 0.16666667f, // 0b110 + 0.25f // 0b111 +}; + +__device__ static float nf4_dequantization_lut[16] = { + -1.0f, // 0b0000 + -0.6961928009986877f, // 0b0001 + -0.5250730514526367f, // 0b0010 + -0.39491748809814453f, // 0b0011 + -0.28444138169288635f, // 0b0100 + -0.18477343022823334f, // 0b0101 + -0.09105003625154495f, // 0b0110 + 0.0f, // 0b0111 + 0.07958029955625534f, // 0b1000 + 0.16093020141124725f, // 0b1001 + 0.24611230194568634f, // 0b1010 + 0.33791524171829224f, // 0b1011 + 0.44070982933044434f, // 0b1100 + 0.5626170039176941f, // 0b1101 + 0.7229568362236023f, // 0b1110 + 1.0f // 0b1111 +}; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda // Luckily we have atomicmax and atomicmin in ROCm - -__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) -{ - float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 111 - return 0.25000000f*absmax*sign; // 1111 - else - return 0.16666667f*absmax*sign; // 1110 - else - if((val & 0b0001) == 1) // 110 - return 0.50000000f*absmax*sign; // 1101 - else - return 0.33333333f*absmax*sign; // 1100 - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 1.00000000f*absmax*sign; // 1011 - else - return 0.66666667f*absmax*sign; // 1010 - else - if((val & 0b0001) == 1) // 100 - return 5.208333333e-03f*absmax*sign; // 1001 - else - return 0.00000000f*absmax*sign; // 1000 +__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { + float sign = 1.0f - 2 * ((val & 0b1000) >> 3); + return fp4_dequantization_lut[val & 0b111] * sign; } __device__ unsigned char dQuantizeFP4(float x) @@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x) return 0b0000+sign; } - -__device__ __forceinline__ float dDequantizeNF4(unsigned char val) -{ - - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if((val & 0b1000) == 8) - if((val & 0b0100) == 4) // 1 - if((val & 0b0010) == 2) // 11 - if((val & 0b0001) == 1) // 111 - return 1.0f; - else - return 0.7229568362236023f; - else - if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; - else - if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; - else - return 0.07958029955625534f; - - else - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 011 - return 0.0f; - else - return -0.09105003625154495f; - else - if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; - else - return -0.28444138169288635f; - else - if((val & 0b0010) == 2) //00 - if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else - if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; - -} +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } __device__ unsigned char dQuantizeNF4(float x) { @@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - unsigned char packed_4bit = 0; switch(DATA_TYPE) { case General8bit: @@ -473,18 +423,16 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { - packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { - packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } @@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: @@ -2507,7 +2455,7 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 16 for(int i = 0; i < 16; i++) - quant_map[i] = nf4_data[i]; + quant_map[i] = nf4_dequantization_lut[i]; //__shared__ T quant_map[16*160]; T local_A[2]; diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9c4cab9cc..b5d9afc6b 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -12,6 +12,9 @@ #if BUILD_MPS // #include #endif +#if BUILD_XPU +#include +#endif #include // Compatibility between HIP/CUDA APIs @@ -308,6 +311,90 @@ void spmm_coo_very_sparse_naive_int8( } #endif +#if BUILD_XPU + +void dequantizeBlockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void gemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); +} + +void gemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + extern "C" { #if BUILD_CUDA || BUILD_HIP void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } @@ -658,6 +745,88 @@ void cgemm_4bit_inference_naive_fp32( #endif +#if BUILD_XPU + +void cdequantize_blockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cgemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + void cquantize_blockwise_cpu_fp32( float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n ) { diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp new file mode 100644 index 000000000..8ee8add98 --- /dev/null +++ b/csrc/xpu_kernels.cpp @@ -0,0 +1,281 @@ +#include "xpu_kernels.h" +#include +#include +#include + +#include + +inline float dDequantizeFP4(unsigned char val) { + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; + else + return 0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; + else if ((val & 0b0001) == 1) + return 5.208333333e-03f; + else + return 0.00000000f; +} + +inline float dDequantizeNF4(unsigned char val) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 + else + return 0.07958029955625534f; //*1000 + + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 + else + return -1.0f; //*0000 +} + +template +SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const { + const int base_idx = item.get_group(0) * TILE_SIZE; + size_t local_idx = item.get_local_id(0) * NUM_PER_TH; + float local_abs_max = -FLT_MAX; + int local_load_idx = 0; + int local_store_idx = 0; + + uint8_t qvals[NUM_PER_TH]; + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; + + if (DATA_TYPE > 0) { + local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); + local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); + } else { + local_load_idx = sycl::min(TILE_SIZE, n - base_idx); + local_store_idx = local_load_idx; + } + + // Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero(blocksize))]; + + if (local_idx + NUM_PER_TH < local_load_idx) { + reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = + reinterpret_cast*>(A)[(base_idx + local_idx) / NUM_PER_TH]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < NUM_PER_TH; i++) { + if (local_idx + i < local_load_idx) { + qvals[i] = A[base_idx + local_idx + i]; + } else { + qvals[i] = (uint8_t)0; + } + } + } + + switch (DATA_TYPE) { + case General8bit: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = code[qvals[j]] * local_abs_max; + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; + } + + const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; + int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; + + if (local_dst_idx + local_dst_size < local_store_idx) { + reinterpret_cast*>( + out + )[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] = + reinterpret_cast(&)[local_dst_size]>(vals)[0]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < local_dst_size; i++) { + if (local_dst_idx + i < local_store_idx) { + out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i]; + } + } + } +} + +template +SYCL_EXTERNAL void + kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { + size_t idx = item.get_local_id(); + const int sg_idx = idx / SUBG_SIZE; + const int sg_lane = idx % SUBG_SIZE; + const int num_values_4bit = SUBG_SIZE; + const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + T local_absmax = T(0.0f); + + if (idx < 16) { + quant_map[idx] = T(datatype[idx]); + } + + item.barrier(sycl::access::fence_space::local_space); + + for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; + + // Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize)); + local_absmax = absmax[absidx]; + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + reinterpret_cast(&)[num_values_8bit]>(local_B_4bit)[0] = + reinterpret_cast*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + if (BITS == 16) { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[1] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + + } else { +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + +// accumulate in float for accuracy; +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + local_C += (float)(local_A[k] * local_B[k]); + } + } + } + + local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); + + if (row_B < N && sg_lane == 0) + out[row_B] = T(local_C); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kgemv_4bit_inference; +template class kgemv_4bit_inference; +template class kgemv_4bit_inference; diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h new file mode 100644 index 000000000..caa7e6716 --- /dev/null +++ b/csrc/xpu_kernels.h @@ -0,0 +1,52 @@ +#include +#include + +#ifndef xpu_kernels +#define xpu_kernels + +template class kDequantizeBlockwise { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_) + : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {} + + private: + float* code; + uint8_t* A; + float* absmax; + T* out; + const int blocksize; + const int n; +}; + +template class kgemv_4bit_inference { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kgemv_4bit_inference( + int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_, + int ldb_, int ldc_, int blocksize_ + ) + : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_), + ldc(ldc_), blocksize(blocksize_), quant_map() {} + + void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor(16, cgh); } + + private: + int M; + int N; + int K; + T* A; + unsigned char* B; + float* absmax; + const float* datatype; + T* out; + int lda; + int ldb; + int ldc; + int blocksize; + sycl::local_accessor quant_map; +}; + +#endif diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp new file mode 100644 index 000000000..aa6ac808f --- /dev/null +++ b/csrc/xpu_ops.cpp @@ -0,0 +1,102 @@ +#include +#include +#include + +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream +) { + auto& queue = *stream; + const int workgroup_size = 128; + const int num_per_th = 4; + const int tile_size = workgroup_size * num_per_th; + if (DATA_TYPE > 0) { + const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } else { + const int workgroup_num = (n + tile_size - 1) / tile_size; + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } +} + +template +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +) { + + auto& queue = *stream; + + const size_t GROUP_SIZE = 128; // workgroup_size + const size_t SUBG_SIZE = 32; // subgroup_size + const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; + size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; + + kgemv_4bit_inference kfn( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize + ); + + sycl_comp_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn + ); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); + +template void gemv_4bit_inference( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +); +template void gemv_4bit_inference( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +); +template void gemv_4bit_inference( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +); diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h new file mode 100644 index 000000000..142d6c161 --- /dev/null +++ b/csrc/xpu_ops.h @@ -0,0 +1,46 @@ +#ifndef xpu_ops_H +#define xpu_ops_H + +#include +#include +#include +#include + +#include +#include + +#include + +template +static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) + [[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for(range, ker); }; + q.submit(cgf); +} + +template +static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] { + ker.sycl_ker_local_memory_creation(cgh); + cgh.parallel_for(range, ker); + }; + q.submit(cgf); +} + +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream +); +template +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +); + +#endif diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index e61ce4655..daa06a3c6 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -16,17 +16,19 @@ Welcome to the installation guide for the `bitsandbytes` library! This document ## CUDA[[cuda]] -`bitsandbytes` is currently supported on NVIDIA GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) 5.0+. -The library can be built using CUDA Toolkit versions as old as **11.6** on Windows and **11.4** on Linux. +`bitsandbytes` is currently supported on NVIDIA GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) 6.0+. +The library can be built using CUDA Toolkit versions as old as **11.8**. | **Feature** | **CC Required** | **Example Hardware Requirement** | |---------------------------------|-----------------|---------------------------------------------| -| LLM.int8() | 7.5+ | Turing (RTX 20 series, T4) or newer GPUs | -| 8-bit optimizers/quantization | 5.0+ | Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs | -| NF4/FP4 quantization | 5.0+ | Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs | +| LLM.int8() | 7.5+ | Turing (RTX 20 series, T4) or newer GPUs | +| 8-bit optimizers/quantization | 6.0+ | Pascal (GTX 10X0 series, P100) or newer GPUs| +| NF4/FP4 quantization | 6.0+ | Pascal (GTX 10X0 series, P100) or newer GPUs| > [!WARNING] -> Support for Maxwell GPUs is deprecated and will be removed in a future release. For the best results, a Turing generation device or newer is recommended. +> Support for Maxwell GPUs is deprecated and will be removed in a future release. +> Maxwell support is not included in PyPI distributions from `v0.48.0` on and must be built from source. +> For the best results, a Turing generation device or newer is recommended. ### Installation via PyPI[[cuda-pip]] @@ -36,12 +38,12 @@ The currently distributed `bitsandbytes` packages are built with the following c | **OS** | **CUDA Toolkit** | **Host Compiler** | **Targets** |--------------------|------------------|----------------------|-------------- -| **Linux x86-64** | 11.8 - 12.6 | GCC 11.2 | sm50, sm60, sm75, sm80, sm86, sm89, sm90 -| **Linux x86-64** | 12.8 | GCC 11.2 | sm75, sm80, sm86, sm89, sm90, sm100, sm120 +| **Linux x86-64** | 11.8 - 12.6 | GCC 11.2 | sm60, sm70, sm75, sm80, sm86, sm89, sm90 +| **Linux x86-64** | 12.8 - 12.9 | GCC 11.2 | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120 | **Linux aarch64** | 11.8 - 12.6 | GCC 11.2 | sm75, sm80, sm90 -| **Linux aarch64** | 12.8 | GCC 11.2 | sm75, sm80, sm90, sm100 +| **Linux aarch64** | 12.8 - 12.9 | GCC 11.2 | sm75, sm80, sm90, sm100, sm120 | **Windows x86-64** | 11.8 - 12.6 | MSVC 19.43+ (VS2022) | sm50, sm60, sm75, sm80, sm86, sm89, sm90 -| **Windows x86-64** | 12.8 | MSVC 19.43+ (VS2022) | sm75, sm80, sm86, sm89, sm90, sm100, sm120 +| **Windows x86-64** | 12.8 - 12.9 | MSVC 19.43+ (VS2022) | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120 Use `pip` or `uv` to install: @@ -67,7 +69,7 @@ For example, to install a compiler and CMake on Ubuntu: apt-get install -y build-essential cmake ``` -You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide. The current minimum supported CUDA Toolkit version that we test with is **11.8**. +You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide. The current minimum supported CUDA Toolkit version that we support is **11.8**. ```bash git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ @@ -84,7 +86,7 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise Compilation from source on Windows systems require Visual Studio with C++ support as well as an installation of the CUDA Toolkit. -To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. The current minimum supported CUDA Toolkit version that we test with is **11.8**. +To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. The current minimum supported CUDA Toolkit version that we support is **11.8**. ```bash git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ @@ -138,8 +140,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d | **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** | |-------------|------------------------|---------------------------|-------------------------|------------| | **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha | -| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | -| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | +| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha | +| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental | | **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental | For each supported backend, follow the respective instructions below: @@ -179,7 +181,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/ * A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance. -* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements. @@ -235,27 +236,18 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise -#### Intel CPU + XPU +#### Intel CPU + GPU(XPU) - -If you are using Intel CPU on Linux or Intel XPU on Linux/Windows, please follow the [instruction](https://pytorch-extension.intel.com/) or the following command to install intel_extension_for_pytorch so you can get better performance. - -CPU: `pip install intel_extension_for_pytorch` -XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/` - -Install bitsandbytes: -CPU: Need to build CPU C++ codes +CPU needs to build CPU C++ codes, while XPU needs to build sycl codes. +Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu. ``` git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ -cmake -DCOMPUTE_BACKEND=cpu -S . +cmake -DCOMPUTE_BACKEND=$bnb_device -S . make -pip install . -``` -XPU: -``` -pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git +pip install -e . ``` + diff --git a/install_cuda.py b/install_cuda.py index c87deaedf..0122be04b 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -87,7 +87,7 @@ def main(): # Install CUDA version(s) if version == "all": - for ver in cuda_versions.keys(): + for ver in cuda_versions: install_cuda(ver, base_path, download_path) elif version in cuda_versions: install_cuda(version, base_path, download_path) diff --git a/pyproject.toml b/pyproject.toml index af4c8c240..61b35c648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools >= 63.0.0"] -build-backend = "setuptools.build_meta" +requires = ["scikit-build-core", "setuptools >= 63.0.0"] +build-backend = "scikit_build_core.setuptools.build_meta" [project] name = "bitsandbytes" @@ -42,8 +42,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence" ] dependencies = [ - "torch>=2.2,<3", - "numpy>=1.17" + "torch>=2.3,<3", + "numpy>=1.17", + "packaging>=20.9" ] [project.urls] @@ -71,7 +72,7 @@ test = [ ] [tool.setuptools] -package-data = { "*" = ["libbitsandbytes*.*"] } +package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] } [tool.setuptools.packages.find] include = ["bitsandbytes*"] @@ -123,11 +124,10 @@ select = [ ignore = [ "B007", # Loop control variable not used within the loop body (TODO: enable) "B028", # Warning without stacklevel (TODO: enable) - "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. + "E501", # Suppress line-too-long warnings: trust yapf's judgement on this one. "E701", # Multiple statements on one line (TODO: enable) "E712", # Allow using if x == False, as it's not always equivalent to if x. "E731", # Do not use lambda - "F841", # Local assigned but not used (TODO: enable, these are likely bugs) "RUF012", # Mutable class attribute annotations "RUF034", # Useless if-else (TODO: enable) "ISC001", # single-line-implicit-string-concatenation incompatible with formatter diff --git a/setup.py b/setup.py index 8c84b2c73..a04630b8a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from distutils.errors import DistutilsModuleError +from warnings import warn + from setuptools import find_packages, setup +from setuptools.command.build_py import build_py from setuptools.dist import Distribution @@ -12,4 +16,26 @@ def has_ext_modules(self): return True -setup(version="0.47.0.dev0", packages=find_packages(), distclass=BinaryDistribution) +class ExtBuildPy(build_py): + def run(self): + # build_cmake needs to be called prior to build_py, as the latter + # collects the files output into the package directory. + try: + self.run_command("build_cmake") + except DistutilsModuleError: + warn( + "scikit-build-core not installed, CMake will not be invoked automatically. " + "Please install scikit-build-core or run CMake manually to build extensions." + ) + super().run() + + +setup( + version="0.48.0.dev0", + packages=find_packages(), + distclass=BinaryDistribution, + cmake_source_dir=".", + cmdclass={ + "build_py": ExtBuildPy, + }, +) diff --git a/tests/conftest.py b/tests/conftest.py index a514e1284..f69b9ff2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,8 @@ def pytest_runtest_teardown(item, nextitem): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + torch.mps.empty_cache() @pytest.fixture(scope="session") diff --git a/tests/helpers.py b/tests/helpers.py index a87bc5d08..f1fa7eb62 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,12 +18,13 @@ @functools.cache -def get_available_devices(): +def get_available_devices(no_cpu=False): if "BNB_TEST_DEVICE" in os.environ: # If the environment variable is set, use it directly. - return [os.environ["BNB_TEST_DEVICE"]] + device = os.environ["BNB_TEST_DEVICE"] + return [] if no_cpu and device == "cpu" else [device] - devices = [] if HIP_ENVIRONMENT else ["cpu"] + devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else [] if hasattr(torch, "accelerator"): # PyTorch 2.6+ - determine accelerator using agnostic API. diff --git a/tests/test_functional.py b/tests/test_functional.py index b84db6502..6a4f72190 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,15 +1,16 @@ import math +import platform import random import time import einops -import numpy as np +from packaging import version import pytest import torch import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -101,16 +102,16 @@ class Test8BitBlockwiseQuantizeFunctional: def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 - if device == "cpu": + if device != "cuda": iters = 10 - # This test is slow on CPU, so avoid atypical use cases. + # This test is slow in our non-CUDA implementations, so avoid atypical use cases. if nested: pytest.skip("Not a typical use case.") if blocksize != 256: - pytest.skip("Only blocksize 256 is used in CPU/XPU") + pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU") if dtype != torch.float32: - pytest.skip("Only float32 is used in CPU/XPU") + pytest.skip("Only float32 is used in CPU/MPS/XPU") diffs = [] reldiffs = [] @@ -142,11 +143,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, abserr = sum(diffs) / len(diffs) relerr = sum(reldiffs) / len(reldiffs) if signed: - threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035 + threshold_abserr = 0.0035 assert abserr < 0.0036 assert relerr < 0.015 else: - assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023 + assert abserr < 0.0023 assert relerr < 0.012 assert A2.dtype == dtype @@ -177,8 +178,8 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): - if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)): - pytest.skip("CPU/XPU implementation only supports 8 bits") + if bits != 8 and device == "cpu": + pytest.skip("CPU implementation only supports 8 bits") abserrs = [] relerrs = [] @@ -239,7 +240,7 @@ def test_fp8_quant(self, device): abserr = [] relerr = [] - for i in range(100): + for i in range(10): A1 = torch.randn(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) @@ -253,7 +254,7 @@ def test_fp8_quant(self, device): abserr = [] relerr = [] - for i in range(100): + for i in range(10): A1 = torch.rand(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) @@ -267,7 +268,7 @@ def test_fp8_quant(self, device): abserr = [] relerr = [] - for i in range(100): + for i in range(10): A1 = torch.randn(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) @@ -462,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim")) @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -1109,6 +1111,7 @@ class TestQuantize4BitFunctional: "blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): pytest.skip("This configuration is not supported on HPU.") @@ -1125,21 +1128,56 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # With larger block sizes, we can expect this to blow up. # At blocksize>=1024, don't even bother looking at relerr. - if blocksize <= 64: - assert err.item() < 0.1 - assert relerr.item() < 0.28 - elif blocksize <= 256: - assert err.item() < 0.11 - assert relerr.item() < 0.30 - elif blocksize <= 512: - assert err.item() < 0.12 - assert relerr.item() < 0.31 - elif quant_type == "fp4": - # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 - assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 - else: - # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 - assert err.item() < math.log2(blocksize) * 8e-2 + # + # Actually, the above is not true anymore after fixing the integer packing bug. + # The following values were taken from averaging 1k samples per test configuration after fixing the bug. + error_dict = dict() + error_dict["fp4"] = dict() + error_dict["nf4"] = dict() + error_dict["fp4"]["err"] = { + 64: 0.096545, + 128: 0.102947, + 256: 0.108685, + 512: 0.114087, + 1024: 0.119312, + 2048: 0.124460, + 4096: 0.129573, + } + error_dict["fp4"]["rel_err"] = { + 64: 0.260130, + 128: 0.275734, + 256: 0.289842, + 512: 0.302852, + 1024: 0.314982, + 2048: 0.326402, + 4096: 0.337228, + } + + error_dict["nf4"]["err"] = { + 64: 0.072792, + 128: 0.076835, + 256: 0.080326, + 512: 0.083535, + 1024: 0.086603, + 2048: 0.089592, + 4096: 0.092537, + } + error_dict["nf4"]["rel_err"] = { + 64: 0.203299, + 128: 0.215252, + 256: 0.226044, + 512: 0.236021, + 1024: 0.245365, + 2048: 0.254146, + 4096: 0.262457, + } + + # Allow higher tolerance for fp32 on CPU with larger block sizes + reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3 + errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3 + + assert err < error_dict[quant_type]["err"][blocksize] + errtol + assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -1238,8 +1276,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double max_errs3 = [] # Large number of iterations is excessive and slow on CPU. - # Keep for CUDA for now. - iters = 100 if device == "cuda" else 10 + # Keep for CUDA/XPU for now. + iters = 10 if device == "cpu" else 100 for i in range(iters): if kind == "fc1": @@ -1341,13 +1379,13 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert err1 < 6e-5 assert relerr1 < 2e-4 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 + assert relratio < 1.005 and relratio > 0.992 + assert maxratio < 1.005 and maxratio > 0.992 elif dtype == torch.float32: if dim <= 512: assert err1 < 5e-8 assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 + assert maxerr1 < 1.05e-7 else: assert err1 < 5e-8 assert relerr1 < 8e-6 @@ -1357,34 +1395,34 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert maxratio < 1.005 and maxratio > 0.995 elif dtype == torch.bfloat16: if dim <= 512: + relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007 assert err1 < 6e-4 - assert relerr1 < 0.007 + assert relerr1 < relerr_thres assert maxerr1 < 0.015 else: assert err1 < 2e-4 assert relerr1 < 0.002 assert maxerr1 < 0.0012 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 + assert relratio < 1.05 and relratio > 0.96 + assert maxratio < 1.05 and maxratio > 0.97 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) - @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) - @pytest.mark.skipif( - HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", - reason="this test is not supported on ROCm with gfx90a architecture yet", - ) - def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): - if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3): - pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3") - + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") + def test_gemv_eye_4bit(self, device, storage_type, dtype): if device == "hpu" and not is_supported_on_hpu(storage_type, dtype): pytest.skip("This configuration is not supported on HPU.") - dims = 10 - torch.random.manual_seed(np.random.randint(0, 412424242)) + if ( + device == "cpu" + and platform.system() == "Windows" + and version.parse(torch.__version__).release == (2, 8, 0) + ): + pytest.skip("Regression: CPU crash on Windows with torch 2.8.0") + + dims = 4 dims = get_test_dims(0, 8192, n=dims) dims = [dim + (64 - (dim % 64)) for dim in dims] # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: @@ -1392,7 +1430,7 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device) B = torch.eye(dim, dtype=dtype, device=device) - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False) C3 = torch.matmul(A, B.t()) C2 = bnb.matmul_4bit(A, qB.t(), state) A.requires_grad = True diff --git a/tests/test_generation.py b/tests/test_generation.py index 38b5ce9bd..3ab1cc5bd 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -112,7 +112,7 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): assert len(outputs) == n_cases failure_count = 0 for i in range(n_cases): - if not outputs[i][: len(str(math.pi))] == str(math.pi): + if outputs[i][: len(str(math.pi))] != str(math.pi): failure_count += 1 failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 if failure_count > failure_max: diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index e07b54d2d..1c5e77a32 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +def test_params4bit_torch_chunk_split(device, quant_type): + """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility.""" + if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8): + pytest.skip("This configuration is not supported on HPU.") + + if device == "cpu": + pytest.skip("CPU quantization causes segfault, skipping CPU test") + + original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu") + + params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False) + + if device != "cpu": + params4bit = params4bit.to(device) + + chunks = torch.chunk(params4bit, 2, dim=0) + + assert isinstance(chunks, tuple), "torch.chunk should return tuple" + for chunk in chunks: + assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass" + assert hasattr(chunk, "quant_type"), "Should preserve metadata" + assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value" + + splits = torch.split(params4bit, 2, dim=0) + + assert isinstance(splits, tuple), "torch.split should return tuple" + assert len(splits) > 0, "Should have at least one split" + for split in splits: + assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass" + assert hasattr(split, "quant_type"), "Should preserve metadata" + assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value" + + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 86726bd44..51b4cf9cd 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -9,6 +9,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, @@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows": pytest.skip("Triton is not officially supported on Windows") @@ -272,14 +274,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): # Test with gradients. Currently only works with threshold=0. # Has a strange regression on Linux aarch64 CPU in torch==2.6.0. - # There is also an issue with torch==2.7.0 on x86-64 with IPEX. is_broken_platform = ( device == "cpu" and platform.system() == "Linux" - and ( - (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7)) - or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu) - ) + and platform.machine() == "aarch64" + and (2, 6) <= torch.__version__ < (2, 7) ) if threshold == 0 and not is_broken_platform: diff --git a/tests/test_modules.py b/tests/test_modules.py index 8946522d3..e5682e5c8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -143,9 +143,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() @@ -156,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device) @@ -167,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -189,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -211,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 diff --git a/tests/test_ops.py b/tests/test_ops.py index 8aa0560fd..02472630e 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,7 +5,6 @@ import bitsandbytes from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu # torch.library.opcheck is only available in torch 2.4 and later. @@ -145,10 +144,6 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): assert out.dtype == dtype assert out.device == A.device - # TODO: Enable it - if device == "xpu" and ipex_xpu: - pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check") - opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype)) @@ -216,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_optim.py b/tests/test_optim.py index 75e5a1714..3d4157152 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -11,7 +11,8 @@ import bitsandbytes as bnb import bitsandbytes.functional as F -from tests.helpers import describe_dtype, id_formatter +from bitsandbytes.utils import sync_gpu +from tests.helpers import describe_dtype, get_available_devices, id_formatter # import apex @@ -168,15 +169,23 @@ def rm_path(path): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) -def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device")) +@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device") +def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): + if device not in ["cuda", "xpu"]: + pytest.skip("Optimizers are only supported on CUDA and XPU") + if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") + if optim_name.startswith("paged_") and device == "xpu": + pytest.skip("Paged optimizers are not supported on XPU currently.") + if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() @@ -191,7 +200,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): atol, rtol = 1e-4, 1e-3 for i in range(k): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -201,14 +210,14 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2].cuda(), + bnb_optimizer.state[p2][name2].to(device), atol=atol, rtol=rtol, ) # since Lion can have pretty noisy updates where things lie at the boundary - # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) + # allow up to 15 errors for Lion + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -247,7 +256,12 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) -def test_global_config(requires_cuda, dim1, dim2, gtype): +@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) +@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device") +def test_global_config(dim1, dim2, gtype, device): + if device not in ["cuda", "xpu"]: + pytest.skip("Optimizers are only supported on CUDA and XPU") + if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 @@ -263,9 +277,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) - p1 = p1.cuda() - p2 = p2.cuda() - p3 = p3.cuda() + p1 = p1.to(device) + p2 = p2.to(device) + p3 = p3.to(device) adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps) @@ -275,9 +289,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): atol, rtol = 1e-4, 1e-3 for i in range(50): - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 p1.grad = g1 p2.grad = g2 p3.grad = g3 @@ -302,13 +316,18 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) -def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) +@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device") +def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): + if device not in ["cuda", "xpu"]: + pytest.skip("8-bit optimizers are only supported on CUDA and XPU") + torch.set_printoptions(precision=6) if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() blocksize = 256 @@ -330,15 +349,15 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): relerrors = [] for i in range(50): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() - bnb_optimizer.step() torch_optimizer.step() + bnb_optimizer.step() # since Lion can have pretty noisy updates where things lie at the boundary - assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0) + # assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0) dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: @@ -368,7 +387,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): ) num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 - # assert num_not_close.sum().item() < 20 + assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) @@ -549,25 +568,25 @@ def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits): @pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) @pytest.mark.benchmark -def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): +def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device): if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 bnb_optimizer = str2optimizers[optim_name][1]([p1]) - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g total_steps = 500 for i in range(total_steps): if i == total_steps // 5: # 100 iterations for burn-in - torch.cuda.synchronize() + sync_gpu(p1) t0 = time.time() bnb_optimizer.step() - torch.cuda.synchronize() + sync_gpu(p1) s = time.time() - t0 print("") params = (total_steps - total_steps // 5) * dim1 * dim2 diff --git a/tests/test_parametrize.py b/tests/test_parametrize.py new file mode 100644 index 000000000..9e661ee2f --- /dev/null +++ b/tests/test_parametrize.py @@ -0,0 +1,411 @@ +import pytest +import torch +import torch.nn as nn + +from bitsandbytes import functional as F +from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.nn.parametrize import ( + Bnb4bitParametrization, + replace_parameter_4bit, + replace_parameter_4bit_prequantized, +) +from tests.helpers import ( + TRUE_FALSE, + describe_dtype, + get_available_devices, + id_formatter, + is_supported_on_hpu, +) + + +class ParametrizeTestModule(nn.Module): + """Test module with different parameter shapes for testing parametrization.""" + + def __init__(self, device="cpu", dtype=torch.float32): + super().__init__() + # 2D parameter (typical weight matrix) + self.weight_2d = nn.Parameter(torch.randn(1024, 1024, device=device, dtype=dtype)) + # 3D parameter (MoE expert weights - the main use case for this feature) + self.expert_weights = nn.Parameter(torch.randn(8, 512, 256, device=device, dtype=dtype)) + # 1D parameter (bias-like) + self.bias_1d = nn.Parameter(torch.randn(1024, device=device, dtype=dtype)) + # Non-parameter attribute (should not be quantizable) + self.not_param = torch.randn(32, device=device, dtype=dtype) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +@pytest.mark.parametrize( + "blocksize", + [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256], +) +def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize): + """Test basic parameter replacement with 4-bit quantization on different dtypes.""" + if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): + pytest.skip("This configuration is not supported on HPU.") + + # Create module directly on target device to avoid unnecessary transfers + module = ParametrizeTestModule(device=device, dtype=dtype) + original_param = module.weight_2d.clone() + + # Apply 4-bit quantization parametrization to the weight parameter + replace_parameter_4bit( + module, "weight_2d", compress_statistics=compress_statistics, quant_type=quant_type, blocksize=blocksize + ) + + # Verify that parametrization was applied correctly + assert hasattr(module, "parametrizations"), "Module should have parametrizations attribute" + assert "weight_2d" in module.parametrizations, "weight_2d should be parametrized" + + # Test that accessing the parameter returns dequantized version with correct properties + reconstructed = module.weight_2d + assert reconstructed.shape == original_param.shape, "Shape should be preserved" + assert reconstructed.dtype == dtype, "dtype should match original" + assert reconstructed.device.type == device, "Device should match target" + + # Verify quantization quality using same approach as functional tests + err = (original_param - reconstructed.detach()).abs().float() + relerr = (err / (original_param.abs().float() + 1e-8)).mean() + err_mean = err.mean() + + # Expected error bounds from test_functional.py + expected_errors = { + "nf4": { + 64: {"abs": 0.072792, "rel": 0.203299}, + 128: {"abs": 0.076835, "rel": 0.215252}, + 256: {"abs": 0.080326, "rel": 0.226044}, + }, + "fp4": { + 64: {"abs": 0.096545, "rel": 0.260130}, + 128: {"abs": 0.102947, "rel": 0.275734}, + 256: {"abs": 0.108685, "rel": 0.289842}, + }, + } + + assert err_mean < expected_errors[quant_type][blocksize]["abs"] + 1e-3, f"Mean abs error {err_mean:.6f} too high" + assert relerr < expected_errors[quant_type][blocksize]["rel"] + 1e-3, f"Mean rel error {relerr:.6f} too high" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +def test_moe_parameter_shape(device, dtype): + """Test parametrization with MoE-style parameter shape""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("This configuration is not supported on HPU.") + + param_shape = (8, 64, 32) + + # Create module with custom parameter shape directly on target device + class MoEModule(nn.Module): + def __init__(self, device, dtype): + super().__init__() + self.param = nn.Parameter(torch.randn(*param_shape, dtype=dtype, device=device)) + + module = MoEModule(device=device, dtype=dtype) + original_param = module.param.clone() + + # Apply quantization parametrization + replace_parameter_4bit(module, "param", quant_type="nf4") + + # Verify reconstruction maintains all properties + reconstructed = module.param + assert reconstructed.shape == param_shape, f"Shape should be preserved: {reconstructed.shape} vs {param_shape}" + assert reconstructed.dtype == dtype, "dtype should match original" + assert reconstructed.device.type == device, "Device should match target" + + # Verify quantization quality using error calculation approach from functional tests + err = (original_param - reconstructed.detach()).abs().float() + relerr = (err / (original_param.abs().float() + 1e-8)).mean() + err_mean = err.mean() + + # Use slightly looser bounds for higher dimensional tensors + abs_bound = 0.085 # NF4 baseline + margin + rel_bound = 0.25 # NF4 baseline + margin + + assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}" + assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +def test_prequantized_replacement(device, dtype, quant_type): + """Test applying parametrization to already quantized parameters.""" + if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + original_param = module.weight_2d.clone() + + # Manually quantize the parameter data first (simulates loading pre-quantized weights) + quantized_data, quant_state = F.quantize_4bit(original_param.data, quant_type=quant_type) + + # Replace parameter with quantized data (what would happen during model loading) + module.weight_2d = nn.Parameter(quantized_data, requires_grad=False) + + # Apply parametrization to handle dequantization on access + replace_parameter_4bit_prequantized( + module, "weight_2d", quant_state.as_dict(packed=True), device=torch.device(device) + ) + + # Test that parameter access properly dequantizes + reconstructed = module.weight_2d + assert reconstructed.shape == original_param.shape, "Shape should be preserved" + assert reconstructed.dtype == dtype, "dtype should match original" + assert reconstructed.device.type == device, "Device should match target" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0") +def test_state_dict_functionality(device, dtype, quant_type, compress_statistics): + """Test that state dict saving works with quantized parameters.""" + if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + + # Apply parametrization to expert weights (main MoE use case) + replace_parameter_4bit(module, "expert_weights", quant_type=quant_type, compress_statistics=compress_statistics) + + # Save state dict - should include quantization state, not parametrization internals + state_dict = module.state_dict() + + # Verify state dict structure: quantized param + quantization metadata + assert "expert_weights" in state_dict, "Quantized parameter should be in state dict" + assert "expert_weights.absmax" in state_dict, "Quantization absmax should be saved" + assert "expert_weights.quant_map" in state_dict, "Quantization map should be saved" + assert f"expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved" + + # Verify parametrization internals are NOT saved (clean state dict) + assert "parametrizations.expert_weights.original" not in state_dict, ( + "Internal parametrization keys should not be saved" + ) + + # Test that the parameter can be accessed after state dict creation + reconstructed = module.expert_weights + assert reconstructed.shape == (8, 512, 256), "Shape should be preserved" + assert reconstructed.dtype == dtype, "dtype should match" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +def test_moe_realistic_forward(device, dtype): + """Test realistic MoE forward computation with quantized expert weights.""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("Configuration not supported on HPU.") + + class SimpleMoE(nn.Module): + def __init__(self, device, dtype): + super().__init__() + # Expert weights: [num_experts, input_dim, output_dim] + self.expert_weights = nn.Parameter(torch.randn(4, 32, 64, dtype=dtype, device=device)) + + def forward(self, x, expert_idx=0): + # Select and use specific expert weight matrix + expert_weight = self.expert_weights[expert_idx] # Shape: [input_dim, output_dim] + return torch.matmul(x, expert_weight) + + module = SimpleMoE(device=device, dtype=dtype) + x = torch.randn(8, 32, dtype=dtype, device=device) + + # Get reference output before quantization + with torch.no_grad(): + reference_output = module(x, expert_idx=1) + + # Apply 4-bit quantization to expert weights + replace_parameter_4bit(module, "expert_weights", quant_type="nf4") + + # Get output after quantization - should be very close to original + with torch.no_grad(): + quantized_output = module(x, expert_idx=1) + + # Verify outputs match within quantization tolerance + assert quantized_output.shape == reference_output.shape, "Output shape should be preserved" + + # Calculate error like functional tests (matrix ops may amplify quantization errors) + err = (reference_output - quantized_output).abs().float() + relerr = (err / (reference_output.abs().float() + 1e-8)).mean() + err_mean = err.mean() + + # Allow for error amplification through matrix multiplication + assert err_mean < 0.5, f"Forward pass mean abs error {err_mean:.6f} too high" + assert relerr < 2.0, f"Forward pass mean rel error {relerr:.6f} too high" + + +def test_error_conditions(): + """Test that proper errors are raised for invalid inputs.""" + module = ParametrizeTestModule() + + # Test AttributeError for non-existent parameter + with pytest.raises(AttributeError, match="Module does not have parameter 'nonexistent'"): + replace_parameter_4bit(module, "nonexistent") + + # Test TypeError for non-Parameter attribute + with pytest.raises(TypeError, match="Parameter 'not_param' is not an instance of nn.Parameter"): + replace_parameter_4bit(module, "not_param") + + # Test same errors for prequantized version + with pytest.raises(AttributeError, match="Module does not have parameter 'nonexistent'"): + replace_parameter_4bit_prequantized(module, "nonexistent", {}, torch.device("cpu")) + + with pytest.raises(TypeError, match="Parameter 'not_param' is not an instance of nn.Parameter"): + replace_parameter_4bit_prequantized(module, "not_param", {}, torch.device("cpu")) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0") +def test_quant_state_preservation(device, dtype): + """Test that quantization state is properly preserved and accessible.""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + + blocksize = 128 if HIP_ENVIRONMENT else 64 + + # Apply parametrization with specific settings + replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize) + + # Verify that quantization state is accessible through parametrization + parametrization = module.parametrizations.weight_2d[0] + assert isinstance(parametrization, Bnb4bitParametrization), "Should be Bnb4bitParametrization instance" + + # Check quantization state properties + quant_state = parametrization.quant_state + assert isinstance(quant_state, F.QuantState), "Should have QuantState" + assert quant_state.quant_type == "nf4", "Quant type should be preserved" + assert quant_state.blocksize == blocksize, "Block size should be preserved" + + # Verify that state dict includes all necessary quantization metadata + state_dict = module.state_dict() + quant_state_dict = quant_state.as_dict(packed=True) + + for key in quant_state_dict.keys(): + full_key = f"weight_2d.{key}" + assert full_key in state_dict, f"Quantization metadata '{full_key}' should be in state dict" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0") +def test_multiple_parameters(device, dtype): + """Test applying parametrization to multiple parameters in the same module.""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + original_2d = module.weight_2d.clone() + original_3d = module.expert_weights.clone() + + # Apply parametrization to multiple parameters, with varying configurations + replace_parameter_4bit(module, "weight_2d", quant_type="nf4", blocksize=128) + replace_parameter_4bit(module, "expert_weights", quant_type="fp4", blocksize=256) + + # Verify both parameters are parametrized and work correctly + reconstructed_2d = module.weight_2d + reconstructed_3d = module.expert_weights + + assert reconstructed_2d.shape == original_2d.shape, "2D parameter shape should be preserved" + assert reconstructed_3d.shape == original_3d.shape, "3D parameter shape should be preserved" + + # Check that state dict includes quantization info for both parameters + state_dict = module.state_dict() + assert "weight_2d" in state_dict, "2D parameter should be in state dict" + assert "expert_weights" in state_dict, "3D parameter should be in state dict" + assert "weight_2d.absmax" in state_dict, "2D parameter quantization metadata should be saved" + assert "expert_weights.absmax" in state_dict, "3D parameter quantization metadata should be saved" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize( + "blocksize", + [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256], +) +def test_different_blocksizes(device, dtype, blocksize): + """Test parametrization with different block sizes to verify flexibility.""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + original_param = module.expert_weights.clone() + + # Apply parametrization with specified block size + replace_parameter_4bit(module, "expert_weights", quant_type="nf4", blocksize=blocksize) + + # Verify reconstruction works with different block sizes + reconstructed = module.expert_weights + assert reconstructed.shape == original_param.shape, "Shape should be preserved" + assert reconstructed.device.type == device, "Device should match" + + # Verify quantization quality using error calculation approach from functional tests + err = (original_param - reconstructed.detach()).abs().float() + relerr = (err / (original_param.abs().float() + 1e-8)).mean() + err_mean = err.mean() + + # Expected error bounds from functional tests (using NF4 bounds since that's what we're testing) + expected_abs = {64: 0.072792, 128: 0.076835, 256: 0.080326} + expected_rel = {64: 0.203299, 128: 0.215252, 256: 0.226044} + + assert err_mean < expected_abs[blocksize] + 0.01, ( + f"Mean abs error {err_mean:.6f} too high for blocksize {blocksize}" + ) + assert relerr < expected_rel[blocksize] + 0.02, f"Mean rel error {relerr:.6f} too high for blocksize {blocksize}" + + +def test_parametrization_forward_method(): + """Test the Bnb4bitParametrization forward method directly.""" + device = "cpu" + + # Create test tensor and manually quantize it + original_tensor = torch.randn(64, 32, dtype=torch.float32, device=device) + quantized_data, quant_state = F.quantize_4bit(original_tensor, quant_type="nf4") + + # Create parametrization instance + parametrization = Bnb4bitParametrization(quant_state) + + # Test forward pass (dequantization) + dequantized = parametrization.forward(quantized_data) + + # Verify dequantization produces correct output + assert dequantized.shape == original_tensor.shape, "Shape should be preserved during dequantization" + assert dequantized.dtype == torch.float32, "dtype should be preserved" + assert dequantized.device == original_tensor.device, "Device should be preserved" + + # Check that dequantization approximates original using mean error calculation + err = (original_tensor - dequantized.detach()).abs().float() + relerr = (err / (original_tensor.abs().float() + 1e-8)).mean() + err_mean = err.mean() + + # Use NF4 bounds from functional tests with small margin + assert err_mean < 0.08, f"Mean abs error {err_mean:.6f} too high" + assert relerr < 0.25, f"Mean rel error {relerr:.6f} too high" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +def test_gradient_behavior(device, dtype): + """Test that quantized parameters have proper gradient behavior.""" + if device == "hpu" and not is_supported_on_hpu("nf4", dtype): + pytest.skip("Configuration not supported on HPU.") + + module = ParametrizeTestModule(device=device, dtype=dtype) + + # Ensure original parameter requires gradients + module.weight_2d.requires_grad_(True) + assert module.weight_2d.requires_grad, "Original parameter should require gradients" + + # Apply quantization parametrization + replace_parameter_4bit(module, "weight_2d", quant_type="nf4") + + # Verify that quantized parameters don't require gradients (expected behavior) + # The underlying quantized parameter should have requires_grad=False + # The dequantized output should also not require gradients + reconstructed = module.weight_2d + assert not reconstructed.requires_grad, "Dequantized parameter should not require gradients" From fcbec797bba6cf1b0ca8c01149a35004dd6d4610 Mon Sep 17 00:00:00 2001 From: pnunna93 <104791500+pnunna93@users.noreply.github.com> Date: Tue, 23 Sep 2025 22:50:34 -0500 Subject: [PATCH 101/102] Fix typo --- .github/scripts/build-cuda.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh index a51dda5e6..9eed06896 100644 --- a/.github/scripts/build-cuda.sh +++ b/.github/scripts/build-cuda.sh @@ -20,7 +20,6 @@ else # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120" - # CUDA 13.0+: Remove < sm75 to align with PyTorch 2.9+cu130 minimum [[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;86;89;90;100;120" fi From 4fa939b3883ca17574333de2935beaabf71b2dba Mon Sep 17 00:00:00 2001 From: pnunna93 <104791500+pnunna93@users.noreply.github.com> Date: Wed, 24 Sep 2025 13:45:06 -0500 Subject: [PATCH 102/102] unskip test_4bit_quant --- tests/test_functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 6a4f72190..072e3b4f5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1111,7 +1111,6 @@ class TestQuantize4BitFunctional: "blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): pytest.skip("This configuration is not supported on HPU.")