diff --git a/CMakeLists.txt b/CMakeLists.txt
index 922b04b89..dc8e5181f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,6 +10,10 @@
 # Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90;100;120`
 # Check your compute capability here: https://developer.nvidia.com/cuda-gpus
 # - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler
+# - ROCM_VERSION: Override the ROCm version shortcode used in the output library name.
+#   Useful when PyTorch was built against a different ROCm version than the one
+#   installed on the system. For example, `-DROCM_VERSION=70` produces
+#   libbitsandbytes_rocm70.so even if the system has ROCm 7.2.
 cmake_minimum_required(VERSION 3.22.1)

 project(bitsandbytes LANGUAGES CXX)
@@ -222,7 +226,15 @@ elseif(BUILD_HIP)
     string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
     string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")

-    string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
+    # Expose a cache variable that the user can set to override the ROCm version in the library name
+    set(ROCM_VERSION "${HIP_VERSION_SHORT}" CACHE STRING "Expected ROCm Version Shortcode")
+
+    message(STATUS "ROCm Version: ${HIP_VERSION_SHORT} (from hipconfig)")
+    if(NOT ROCM_VERSION STREQUAL "${HIP_VERSION_SHORT}")
+        message(WARNING "Overriding ROCm version in library name: ${HIP_VERSION_SHORT} -> ${ROCM_VERSION}")
+    endif()
+
+    string(APPEND BNB_OUTPUT_NAME "${ROCM_VERSION}")
     add_compile_definitions(__HIP_PLATFORM_AMD__)
     add_compile_definitions(__HIP_PLATFORM_HCC__)
     add_compile_definitions(BUILD_HIP)
diff --git a/agents/architecture_guide.md b/agents/architecture_guide.md
index b10211b74..ddf90bdbb 100644
--- a/agents/architecture_guide.md
+++ b/agents/architecture_guide.md
@@ -329,6 +329,7 @@ GPU-specific functions are actually invoked.
 ### Environment variables

 - `BNB_CUDA_VERSION` — Override the auto-detected CUDA version for library selection
+  - `BNB_ROCM_VERSION` is the ROCm equivalent, used to override the detected ROCm version
 - Standard CUDA env vars (`CUDA_HOME`, `LD_LIBRARY_PATH`) affect library discovery

 ---
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 188576225..11a5cffb7 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -32,11 +32,21 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
     library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

     override_value = os.environ.get("BNB_CUDA_VERSION")
-    if override_value:
+    rocm_override_value = os.environ.get("BNB_ROCM_VERSION")
+
+    if rocm_override_value and torch.version.hip:
+        library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
+        logger.warning(
+            f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
+            "This can be used to load a bitsandbytes build made for a ROCm version that differs from the PyTorch ROCm version.\n"
+            "If this was unintended, set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
+        )
+    elif override_value:
         library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
         if torch.version.hip:
             raise RuntimeError(
                 f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
+                "Use BNB_ROCM_VERSION instead, e.g.: export BNB_ROCM_VERSION=<version>\n"
                 f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
             )
         logger.warning(
@@ -122,7 +132,7 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
     1. Missing shared library dependencies (e.g., libcudart.so not in LD_LIBRARY_PATH or through PyTorch CUDA installation)
     2. CUDA version mismatch between PyTorch and available pre-compiled binaries
     3. Completely missing pre-compiled binaries when CUDA is detected
-    4. Custom BNB_CUDA_VERSION override but mismatch
+    4. Custom BNB_CUDA_VERSION or BNB_ROCM_VERSION override that does not match an available binary
     5. CPU-only installation attempts when GPU functionality is requested
     """
@@ -131,7 +141,9 @@ def __init__(self, error_msg: str):
         self.error_msg = error_msg
         self.user_cuda_version = get_cuda_version_tuple()
         self.available_versions = get_available_cuda_binary_versions()
-        self.override_value = os.environ.get("BNB_CUDA_VERSION")
+        self.override_value = (
+            os.environ.get("BNB_ROCM_VERSION") if HIP_ENVIRONMENT else os.environ.get("BNB_CUDA_VERSION")
+        )
         self.requested_version = (
             parse_cuda_version(self.override_value)
             if self.override_value
@@ -217,8 +229,10 @@ def _format_lib_error_message(
         )
         if not HIP_ENVIRONMENT
         else (
-            "You can COMPILE FROM SOURCE as mentioned here:\n"
+            "You have two options:\n"
+            "1. COMPILE FROM SOURCE as mentioned here:\n"
             "  https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
+            "2. Set BNB_ROCM_VERSION to the ROCm version the library was built with, if it differs from the detected one.\n\n"
         )
     )

diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py
index 29a9a66e1..de4d036cb 100644
--- a/bitsandbytes/diagnostics/cuda.py
+++ b/bitsandbytes/diagnostics/cuda.py
@@ -135,15 +135,21 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
 def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")

+    rocm_override = os.environ.get("BNB_ROCM_VERSION")
+    if rocm_override:
+        print(f"BNB_ROCM_VERSION override: {rocm_override}")
+
     binary_path = get_cuda_bnb_library_path(cuda_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
-        Library not found: {binary_path}.
-        Maybe you need to compile it from source? If you compiled from source, check that ROCm version
-        in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
-        and rebuild bitsandbytes.
-        """,
+            Library not found: {binary_path}.
+            Maybe you need to compile it from source? If you compiled from source, check that the ROCm version
+            in PyTorch Settings matches your ROCm install. If not, you can either:
+            1. Reinstall PyTorch for your ROCm version and rebuild bitsandbytes.
+            2. Set BNB_ROCM_VERSION to match the version the library was built with.
+               For example: export BNB_ROCM_VERSION=72
+            """,
         )

     hip_major, hip_minor = cuda_specs.cuda_version_tuple
@@ -192,7 +198,7 @@ def _print_cuda_runtime_diagnostics() -> None:
 def _print_hip_runtime_diagnostics() -> None:
     cudart_paths = list(find_cudart_libraries())
     if not cudart_paths:
-        print("WARNING! ROCm runtime files not found in any environmental path.")
+        print("ROCm SETUP: WARNING! ROCm runtime files not found in any environmental path.")
     elif len(cudart_paths) > 1:
         print_dedented(
             f"""
@@ -200,14 +206,18 @@ def _print_hip_runtime_diagnostics() -> None:

             We select the PyTorch default ROCm runtime, which is {torch.version.hip},
             but this might mismatch with the ROCm version that is needed for bitsandbytes.
+            To override this behavior set the `BNB_ROCM_VERSION=<version string, e.g. 72>` environmental variable.
+
+            For example, if you want to use the ROCm version 7.2,
+              BNB_ROCM_VERSION=72 python ...
-            To resolve it, install PyTorch built for the ROCm version you want to use
+            OR set the environmental variable in your .bashrc:
+              export BNB_ROCM_VERSION=72
-            and set LD_LIBRARY_PATH to your ROCm install path, e.g.
-            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib,
+            In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
+            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-7.2.0/lib,
             """,
         )
-
     for pth in cudart_paths:
         print(f"* Found ROCm runtime at: {pth}")
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 3d8b688ee..f74f05634 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -24,3 +24,49 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
     monkeypatch.setenv("BNB_CUDA_VERSION", "110")
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
     assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?
+
+
+# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
+@pytest.fixture
+def rocm70_spec() -> CUDASpecs:
+    return CUDASpecs(
+        cuda_version_string="70",  # from torch.version.hip == "7.0.x"
+        highest_compute_capability=(0, 0),  # unused for ROCm library path resolution
+        cuda_version_tuple=(7, 0),
+    )
+
+
+@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
+def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
+    """Without override, library path uses PyTorch's ROCm 7.0 version."""
+    monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
+    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
+    assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70"
+
+
+@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
+def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
+    """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
+    monkeypatch.setenv("BNB_ROCM_VERSION", "72")
+    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
+    assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
+    assert "BNB_ROCM_VERSION" in caplog.text
+
+
+@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
+def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
+    """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
+    monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
+    monkeypatch.setenv("BNB_CUDA_VERSION", "72")
+    with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
+        get_cuda_bnb_library_path(rocm70_spec)
+
+
+@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
+def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
+    """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
+    monkeypatch.setenv("BNB_ROCM_VERSION", "72")
+    monkeypatch.setenv("BNB_CUDA_VERSION", "72")
+    assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
+    assert "BNB_ROCM_VERSION" in caplog.text
+    assert "BNB_CUDA_VERSION" not in caplog.text
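
Reviewer note: the sketch below condenses the runtime override precedence implemented in get_cuda_bnb_library_path (bitsandbytes/cextension.py) and exercised by the tests above into one standalone function. It is a minimal sketch only: resolve_library_name and is_hip are illustrative names, not part of the bitsandbytes API, and the logging/warning side effects of the real code are omitted.

    # Minimal sketch, not the actual bitsandbytes implementation.
    import os
    import re


    def resolve_library_name(library_name: str, is_hip: bool) -> str:
        rocm_override = os.environ.get("BNB_ROCM_VERSION")
        cuda_override = os.environ.get("BNB_CUDA_VERSION")

        if rocm_override and is_hip:
            # On ROCm, BNB_ROCM_VERSION takes priority over BNB_CUDA_VERSION.
            return re.sub(r"rocm\d+", f"rocm{rocm_override}", library_name, count=1)
        if cuda_override:
            if is_hip:
                # BNB_CUDA_VERSION alone is rejected on ROCm, pointing the user to BNB_ROCM_VERSION.
                raise RuntimeError("BNB_CUDA_VERSION detected for ROCm; use BNB_ROCM_VERSION instead")
            return re.sub(r"cuda\d+", f"cuda{cuda_override}", library_name, count=1)
        return library_name


    # Example: PyTorch ships ROCm 7.0, but the bitsandbytes binary was built for ROCm 7.2.
    os.environ["BNB_ROCM_VERSION"] = "72"
    print(resolve_library_name("libbitsandbytes_rocm70.so", is_hip=True))  # libbitsandbytes_rocm72.so

The build-time counterpart is the new ROCM_VERSION CMake cache variable; per the CMakeLists.txt change it only renames the rocm suffix of the output library, it does not change which ROCm toolchain the kernels are compiled with.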