@lru_cache(None)
def get_cuda_version_tuple() -> Optional[tuple[int, int]]:
    """Return the CUDA or HIP version PyTorch reports as ``(major, minor)``.

    ``torch.version.cuda`` is consulted first and ``torch.version.hip`` is
    used only when the former is unset/empty.

    Returns:
        A ``(major, minor)`` tuple of ints, or ``None`` when neither version
        string is set, the string has fewer than two dot-separated
        components, or the first two components are not integers.

    The result (including ``None``) is memoised by ``lru_cache`` for the
    lifetime of the process.
    """
    try:
        version_str = torch.version.cuda or torch.version.hip
        if not version_str:
            return None

        parts = version_str.split(".")
        if len(parts) < 2:
            return None

        # Build the pair explicitly so the value really is a 2-tuple, matching
        # the annotation; any components after the minor version are ignored.
        return int(parts[0]), int(parts[1])
    except (AttributeError, ValueError):
        # AttributeError: the attribute is missing or is not a string.
        # ValueError: a major/minor component is not an integer literal.
        return None


def get_cuda_version_string() -> Optional[str]:
    """Return the CUDA/HIP version as a compact digit string.

    The string is ``str(major * 10 + minor)``, e.g. ``(11, 8)`` -> ``"118"``.

    Returns:
        The digit string, or ``None`` when ``get_cuda_version_tuple()``
        returns ``None``.
    """
    version_tuple = get_cuda_version_tuple()
    if version_tuple is None:
        return None

    major, minor = version_tuple
    # NOTE(review): a minor version >= 10 carries into the major digit
    # (e.g. (12, 10) -> "130"); kept as-is because callers may depend on the
    # existing format - confirm before changing.
    return f"{major * 10 + minor}"


def get_cuda_specs() -> Optional[CUDASpecs]:
    """Collect the CUDA/HIP specifications of the current environment.

    Returns:
        A ``CUDASpecs`` holding the last entry of
        ``get_compute_capabilities()`` together with the version string and
        version tuple, or ``None`` when ``torch.cuda.is_available()`` is
        false, no compute capabilities are reported, the version cannot be
        determined, or any step raises.
    """
    if not torch.cuda.is_available():
        return None

    try:
        compute_capabilities = get_compute_capabilities()
        if not compute_capabilities:
            return None

        version_tuple = get_cuda_version_tuple()
        if version_tuple is None:
            return None

        version_string = get_cuda_version_string()
        if version_string is None:
            return None

        return CUDASpecs(
            highest_compute_capability=compute_capabilities[-1],
            cuda_version_string=version_string,
            cuda_version_tuple=version_tuple,
        )
    except Exception:
        # Deliberately best-effort: detection must never raise to the caller.
        # Record the reason so that a swallowed failure stays diagnosable.
        import logging

        logging.getLogger(__name__).debug(
            "Failed to determine CUDA/HIP specs", exc_info=True
        )
        return None