Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90;100;120`
# Check your compute capability here: https://developer.nvidia.com/cuda-gpus
# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler
# - ROCM_VERSION: Override the ROCm version shortcode used in the output library name.
# Useful when PyTorch was built against a different ROCm version than the
# system install. For example, `-DROCM_VERSION=70` produces
# libbitsandbytes_rocm70.so even if the system has ROCm 7.2.
cmake_minimum_required(VERSION 3.22.1)

project(bitsandbytes LANGUAGES CXX)
Expand Down Expand Up @@ -222,7 +226,15 @@ elseif(BUILD_HIP)
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")

string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
# Expose a cache variable that the user can set to override the ROCm version in the library name
set(ROCM_VERSION "${HIP_VERSION_SHORT}" CACHE STRING "Expected ROCm Version Shortcode")

message(STATUS "ROCm Version: ${HIP_VERSION_SHORT} (from hipconfig)")
if(NOT ROCM_VERSION STREQUAL "${HIP_VERSION_SHORT}")
message(WARNING "Overriding ROCm version in library name: ${HIP_VERSION_SHORT} -> ${ROCM_VERSION}")
endif()

string(APPEND BNB_OUTPUT_NAME "${ROCM_VERSION}")
add_compile_definitions(__HIP_PLATFORM_AMD__)
add_compile_definitions(__HIP_PLATFORM_HCC__)
add_compile_definitions(BUILD_HIP)
Expand Down
1 change: 1 addition & 0 deletions agents/architecture_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ GPU-specific functions are actually invoked.
### Environment variables

- `BNB_CUDA_VERSION` — Override the auto-detected CUDA version for library selection
- `BNB_ROCM_VERSION` — Override the auto-detected ROCm version for library selection (ROCm equivalent of `BNB_CUDA_VERSION`)
- Standard CUDA env vars (`CUDA_HOME`, `LD_LIBRARY_PATH`) affect library discovery

---
Expand Down
22 changes: 18 additions & 4 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,21 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")

if rocm_override_value and torch.version.hip:
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
logger.warning(
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
"If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
)
elif override_value:
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
if torch.version.hip:
raise RuntimeError(
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
)
logger.warning(
Expand Down Expand Up @@ -122,7 +132,7 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
1. Missing shared library dependencies (e.g., libcudart.so not in LD_LIBRARY_PATH or through PyTorch CUDA installation)
2. CUDA version mismatch between PyTorch and available pre-compiled binaries
3. Completely missing pre-compiled binaries when CUDA is detected
4. Custom BNB_CUDA_VERSION override but mismatch
4. Custom BNB_CUDA_VERSION or BNB_ROCM_VERSION override but mismatch
5. CPU-only installation attempts when GPU functionality is requested

"""
Expand All @@ -131,7 +141,9 @@ def __init__(self, error_msg: str):
self.error_msg = error_msg
self.user_cuda_version = get_cuda_version_tuple()
self.available_versions = get_available_cuda_binary_versions()
self.override_value = os.environ.get("BNB_CUDA_VERSION")
self.override_value = (
os.environ.get("BNB_ROCM_VERSION") if HIP_ENVIRONMENT else os.environ.get("BNB_CUDA_VERSION")
)
self.requested_version = (
parse_cuda_version(self.override_value)
if self.override_value
Expand Down Expand Up @@ -217,8 +229,10 @@ def _format_lib_error_message(
)
if not HIP_ENVIRONMENT
else (
"You can COMPILE FROM SOURCE as mentioned here:\n"
"You have two options:\n"
"1. COMPILE FROM SOURCE as mentioned here:\n"
" https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
"2. Use BNB_ROCM_VERSION to specify a DIFFERENT ROCm version from the detected one, matching the version the library was built with.\n\n"
)
)

Expand Down
30 changes: 20 additions & 10 deletions bitsandbytes/diagnostics/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,21 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")

rocm_override = os.environ.get("BNB_ROCM_VERSION")
if rocm_override:
print(f"BNB_ROCM_VERSION override: {rocm_override}")

binary_path = get_cuda_bnb_library_path(cuda_specs)
if not binary_path.exists():
print_dedented(
f"""
Library not found: {binary_path}.
Maybe you need to compile it from source? If you compiled from source, check that ROCm version
in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
and rebuild bitsandbytes.
""",
Library not found: {binary_path}.
Maybe you need to compile it from source? If you compiled from source, check that ROCm version
in PyTorch Settings matches your ROCm install. If not, you can either:
1. Reinstall PyTorch for your ROCm version and rebuild bitsandbytes.
2. Set BNB_ROCM_VERSION to match the version the library was built with.
For example: export BNB_ROCM_VERSION=72
""",
)

hip_major, hip_minor = cuda_specs.cuda_version_tuple
Expand Down Expand Up @@ -192,22 +198,26 @@ def _print_cuda_runtime_diagnostics() -> None:
def _print_hip_runtime_diagnostics() -> None:
cudart_paths = list(find_cudart_libraries())
if not cudart_paths:
print("WARNING! ROCm runtime files not found in any environmental path.")
print("ROCm SETUP: WARNING! ROCm runtime files not found in any environmental path.")
elif len(cudart_paths) > 1:
print_dedented(
f"""
Found duplicate ROCm runtime files (see below).

We select the PyTorch default ROCm runtime, which is {torch.version.hip},
but this might mismatch with the ROCm version that is needed for bitsandbytes.
To override this behavior set the `BNB_ROCM_VERSION=<version string, e.g. 72>` environmental variable.

For example, if you want to use the ROCm version 7.2,
BNB_ROCM_VERSION=72 python ...

To resolve it, install PyTorch built for the ROCm version you want to use
OR set the environmental variable in your .bashrc:
export BNB_ROCM_VERSION=72

and set LD_LIBRARY_PATH to your ROCm install path, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib,
In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-7.2.0/lib,
""",
)

for pth in cudart_paths:
print(f"* Found ROCm runtime at: {pth}")

Expand Down
46 changes: 46 additions & 0 deletions tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,49 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
@pytest.fixture
def rocm70_spec() -> CUDASpecs:
    """Specs as PyTorch would report them when built against ROCm 7.0."""
    return CUDASpecs(
        cuda_version_string="70",  # derived from torch.version.hip == "7.0.x"
        cuda_version_tuple=(7, 0),
        highest_compute_capability=(0, 0),  # not consulted when resolving ROCm library paths
    )


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
    """Without override, library path uses PyTorch's ROCm 7.0 version."""
    # Make sure neither override variable leaks in from the environment.
    for env_var in ("BNB_ROCM_VERSION", "BNB_CUDA_VERSION"):
        monkeypatch.delenv(env_var, raising=False)
    resolved = get_cuda_bnb_library_path(rocm70_spec)
    assert resolved.stem == "libbitsandbytes_rocm70"


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
    """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
    monkeypatch.setenv("BNB_ROCM_VERSION", "72")
    resolved = get_cuda_bnb_library_path(rocm70_spec)
    assert resolved.stem == "libbitsandbytes_rocm72"
    # The override must be surfaced to the user via a logged warning.
    assert "BNB_ROCM_VERSION" in caplog.text


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
    """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
    monkeypatch.setenv("BNB_CUDA_VERSION", "72")
    monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
    # The error message should name the offending variable and the platform.
    with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
        get_cuda_bnb_library_path(rocm70_spec)


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
    """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
    for env_var in ("BNB_ROCM_VERSION", "BNB_CUDA_VERSION"):
        monkeypatch.setenv(env_var, "72")
    resolved = get_cuda_bnb_library_path(rocm70_spec)
    assert resolved.stem == "libbitsandbytes_rocm72"
    # Only the ROCm-override warning should fire; the CUDA branch must not be taken.
    assert "BNB_ROCM_VERSION" in caplog.text
    assert "BNB_CUDA_VERSION" not in caplog.text
Loading