Description
We currently build and package our kernels against the following CUDA Toolkit versions on all supported platforms:
- 11.8
- 12.0, 12.1, 12.2, 12.3, 12.4, 12.5, 12.6, 12.8, 12.9
- 13.0
This inflates the size of our wheels and lengthens our build times.
By default, we try to load the binary built with the same CUDA Toolkit version as the user's PyTorch build, but this can be overridden with the `BNB_CUDA_VERSION` environment variable.
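A minimal sketch of that default selection, assuming that `BNB_CUDA_VERSION` holds a compact version string such as `"124"` and that the per-version binaries follow a `libbitsandbytes_cuda<ver>.so` naming scheme (both are assumptions here, and the helper names are hypothetical). The PyTorch CUDA version is passed in as a string like the one `torch.version.cuda` reports, so the sketch runs without PyTorch installed.

```python
import os
from typing import Mapping, Optional


def resolve_cuda_tag(torch_cuda: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Return the compact CUDA tag (e.g. "124") of the binary to load.

    Hypothetical helper: a non-empty BNB_CUDA_VERSION wins; otherwise the
    tag is derived from PyTorch's CUDA version ("12.4" -> "124").
    """
    env = os.environ if env is None else env
    override = env.get("BNB_CUDA_VERSION", "").strip()
    if override:
        return override
    # Keep only major.minor, so "12.4.1" would also map to "124".
    major, minor = torch_cuda.split(".")[:2]
    return f"{major}{minor}"


def library_name(tag: str) -> str:
    # Assumed naming scheme for the per-CUDA-version shared library.
    return f"libbitsandbytes_cuda{tag}.so"


if __name__ == "__main__":
    print(library_name(resolve_cuda_tag("12.4", env={})))
    print(library_name(resolve_cuda_tag("12.4", env={"BNB_CUDA_VERSION": "121"})))
```

Passing `env` explicitly keeps the sketch testable; in real use it would fall back to `os.environ`.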
Let's align more closely with the official PyTorch wheels, since these are the most commonly used. We currently support PyTorch 2.3+ on CUDA.
| PyTorch Version | CUDA Toolkit Versions |
|---|---|
| 2.11 (Prerelease) | 12.6, 12.8, 13.0 |
| 2.10 | 12.6, 12.8, 13.0 |
| 2.9 | 12.6, 12.8, 13.0 |
| 2.8 | 12.6, 12.8, 12.9 |
| 2.7 | 11.8, 12.6, 12.8 |
| 2.6 | 11.8, 12.4, 12.6 |
| 2.5 | 11.8, 12.1, 12.4 |
| 2.4 | 11.8, 12.1, 12.4 |
| 2.3 | 11.8, 12.1 |
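The build list and the PyTorch table together determine which builds no official wheel needs. As a quick check of that arithmetic, here is a short sketch with both transcribed by hand (the variable and function names are illustrative only):

```python
# CUDA Toolkit versions we currently build against.
current_builds = {
    "11.8",
    "12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6", "12.8", "12.9",
    "13.0",
}

# CUDA Toolkit versions used by official PyTorch wheels, per release.
pytorch_cuda = {
    "2.11": {"12.6", "12.8", "13.0"},
    "2.10": {"12.6", "12.8", "13.0"},
    "2.9": {"12.6", "12.8", "13.0"},
    "2.8": {"12.6", "12.8", "12.9"},
    "2.7": {"11.8", "12.6", "12.8"},
    "2.6": {"11.8", "12.4", "12.6"},
    "2.5": {"11.8", "12.1", "12.4"},
    "2.4": {"11.8", "12.1", "12.4"},
    "2.3": {"11.8", "12.1"},
}


def version_key(v: str) -> tuple:
    """Compare versions numerically so that e.g. "12.10" sorts after "12.9"."""
    major, minor = v.split(".")
    return int(major), int(minor)


# Every CUDA version some supported PyTorch release ships with.
needed = set().union(*pytorch_cuda.values())
# Builds that no supported PyTorch release uses.
removable = current_builds - needed

print(len(current_builds))                 # 11
print(sorted(needed, key=version_key))     # ['11.8', '12.1', '12.4', '12.6', '12.8', '12.9', '13.0']
print(sorted(removable, key=version_key))  # ['12.0', '12.2', '12.3', '12.5']
```

Every version in `needed` is already part of `current_builds`, so the difference is exactly the set of builds that only serve non-default PyTorch configurations.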
We would remove 4 of our 11 builds: CUDA 12.0, 12.2, 12.3, and 12.5. When a user runs a PyTorch build compiled against a CUDA version for which no matching bitsandbytes binary is installed, we should fall back to loading the closest available build, restricted to the same major CUDA version. We would first search the minor versions below PyTorch's CUDA version, nearest first, before moving on to the minor versions above it, again nearest first. The `BNB_CUDA_VERSION` override would still take precedence over this fallback.
| PyTorch CUDA | BNB Fallback Priority |
|---|---|
| 11.8 | - |
| 12.0 | 12.1, 12.2, ..., 12.9 |
| 12.1 | 12.0, 12.2, 12.3, ..., 12.9 |
| 12.2 | 12.1, 12.0, 12.3, ..., 12.9 |
| 12.3 | 12.2, 12.1, 12.0, 12.4, 12.5, ..., 12.9 |
| 12.4 | 12.3, 12.2, 12.1, 12.0, 12.5, ..., 12.9 |
| 12.5 | 12.4, 12.3, ..., 12.0, 12.6, 12.8, 12.9 |
| 12.6 | 12.5, 12.4, ..., 12.0, 12.8, 12.9 |
| 12.8 | 12.6, 12.5, ..., 12.0, 12.9 |
| 12.9 | 12.8, 12.6, ..., 12.0 |
| 13.0 | 13.1, 13.2, ... |
| 13.1 | 13.0, 13.2, ... |
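A minimal sketch of the proposed lookup, assuming the installed builds are known as `(major, minor)` tuples and that an override, when given, is used as-is. The function names are hypothetical and not part of bitsandbytes. The ordering follows the priority column above: lower minor versions first (nearest first), then higher ones (nearest first), never crossing the major version.

```python
from typing import Iterable, List, Optional, Tuple

Version = Tuple[int, int]  # (major, minor), e.g. (12, 4)


def fallback_order(torch_cuda: Version, available: Iterable[Version]) -> List[Version]:
    """Candidate builds in priority order, excluding an exact match."""
    major, minor = torch_cuda
    same_major = {v for v in available if v[0] == major and v != torch_cuda}
    # Lower minor versions first, nearest first (descending).
    lower = sorted((v for v in same_major if v[1] < minor), reverse=True)
    # Then higher minor versions, nearest first (ascending).
    higher = sorted(v for v in same_major if v[1] > minor)
    return lower + higher


def select_build(
    torch_cuda: Version,
    available: Iterable[Version],
    override: Optional[Version] = None,
) -> Optional[Version]:
    """Pick the build to load; None if no build shares the major version."""
    if override is not None:
        # An explicit BNB_CUDA_VERSION supersedes any automatic choice.
        return override
    available = set(available)
    if torch_cuda in available:
        return torch_cuda
    candidates = fallback_order(torch_cuda, available)
    return candidates[0] if candidates else None


if __name__ == "__main__":
    # The seven builds left after dropping 12.0, 12.2, 12.3 and 12.5.
    builds = [(11, 8), (12, 1), (12, 4), (12, 6), (12, 8), (12, 9), (13, 0)]
    print(select_build((12, 3), builds))  # (12, 1)
    print(select_build((12, 5), builds))  # (12, 4)
    print(select_build((12, 0), builds))  # (12, 1)
    print(select_build((13, 1), builds))  # (13, 0)
```

Passing the full 12.0 to 12.9 range as `available` reproduces each CUDA 12 row of the table; with only the seven remaining builds installed, a PyTorch build for CUDA 12.3 would load the 12.1 binary.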
This would reduce our unpacked wheel size by about 35% (~73.4 MB) on Linux x86-64.