From 6074e0e45c069af8ddb5a61e44e45ec7f91b3f47 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Wed, 16 Apr 2025 14:25:35 +0530 Subject: [PATCH 1/3] fix: Improve CUDA version detection and error handling --- bitsandbytes/cuda_specs.py | 57 ++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index ff90c50ed..27e2bda3d 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,6 @@ import dataclasses from functools import lru_cache -from typing import Optional +from typing import Optional, Tuple import torch @@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]: @lru_cache(None) -def get_cuda_version_tuple() -> tuple[int, int]: - if torch.version.cuda: - return tuple(map(int, torch.version.cuda.split(".")[0:2])) - elif torch.version.hip: - return tuple(map(int, torch.version.hip.split(".")[0:2])) +def get_cuda_version_tuple() -> Optional[Tuple[int, int]]: + """Get CUDA/HIP version as a tuple of (major, minor).""" + try: + if torch.version.cuda: + version_str = torch.version.cuda + elif torch.version.hip: + version_str = torch.version.hip + else: + return None - return None + parts = version_str.split(".") + if len(parts) >= 2: + return tuple(map(int, parts[:2])) + return None + except (AttributeError, ValueError, IndexError): + return None -def get_cuda_version_string() -> str: - major, minor = get_cuda_version_tuple() +def get_cuda_version_string() -> Optional[str]: + """Get CUDA/HIP version as a string.""" + version_tuple = get_cuda_version_tuple() + if version_tuple is None: + return None + major, minor = version_tuple return f"{major * 10 + minor}" def get_cuda_specs() -> Optional[CUDASpecs]: + """Get CUDA/HIP specifications.""" if not torch.cuda.is_available(): return None - return CUDASpecs( - highest_compute_capability=(get_compute_capabilities()[-1]), - cuda_version_string=(get_cuda_version_string()), - cuda_version_tuple=get_cuda_version_tuple(), - ) + try: + compute_capabilities = get_compute_capabilities() + if not compute_capabilities: + return None + + version_tuple = get_cuda_version_tuple() + if version_tuple is None: + return None + + version_string = get_cuda_version_string() + if version_string is None: + return None + + return CUDASpecs( + highest_compute_capability=compute_capabilities[-1], + cuda_version_string=version_string, + cuda_version_tuple=version_tuple, + ) + except Exception: + return None From 28d4d5c776ba07f54804fae5e909edbefaea6f15 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Wed, 16 Apr 2025 21:18:08 +0530 Subject: [PATCH 2/3] lint fix --- bitsandbytes/cuda_specs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 27e2bda3d..285872e4f 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -73,3 +73,4 @@ def get_cuda_specs() -> Optional[CUDASpecs]: ) except Exception: return None + From 1724fa14722c5d860600c71764b6487a0aa14122 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Wed, 16 Apr 2025 21:22:03 +0530 Subject: [PATCH 3/3] lint fix --- bitsandbytes/cuda_specs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 285872e4f..64903cd49 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,6 @@ import dataclasses from functools import lru_cache -from typing import Optional, Tuple +from typing import Optional import torch @@ -21,7 +21,7 @@ def get_compute_capabilities() -> list[tuple[int, int]]: @lru_cache(None) -def get_cuda_version_tuple() -> Optional[Tuple[int, int]]: +def get_cuda_version_tuple() -> Optional[tuple[int, int]]: """Get CUDA/HIP version as a tuple of (major, minor).""" try: if torch.version.cuda: @@ -73,4 +73,3 @@ def get_cuda_specs() -> Optional[CUDASpecs]: ) except Exception: return None -