[Discussion] Unify CUDA and HIP kernel sources via compat.cuh portability layer #1877

Draft
Abdennacer-Badaoui wants to merge 4 commits into bitsandbytes-foundation:main from Abdennacer-Badaoui:merge-cuda-hip

Conversation

@Abdennacer-Badaoui (Member) commented Feb 18, 2026

RFC — Not intended to be merged as-is

This PR proposes a design for merging the duplicated CUDA and HIP kernel sources into a single codebase. The csrc/examples/ directory contains the full unified files demonstrating the approach. This is meant for discussion and feedback before we proceed with a full migration.

Problem

We maintain near-identical copies of every GPU kernel:

| CUDA        | HIP             | ~LOC each |
|-------------|-----------------|-----------|
| kernels.cu  | kernels.hip     | 2600+     |
| kernels.cuh | kernels_hip.cuh | 130       |
| ops.cu      | ops.hip         | 650+      |
| ops.cuh     | ops_hip.cuh     | 190       |
| common.cuh  | common_hip.cuh  | 45 / 11   |

The HIP files were originally auto-generated by hipify and manually patched. Every bug fix or new feature must be applied to both copies, and they inevitably drift apart.

Proposed design

Introduce two portability headers:

  • compat.cuh — Host-safe types and macros (safe to include from .cpp files)
  • compat_device.cuh — Device-only layer: CUB/hipCUB, reduction ops, MMA (include from .cu files only)

These resolve all mechanical CUDA/HIP differences via macros, type aliases, and namespace aliases (a sketch of the host-safe header follows this list):

  • bnb_cub:: → cub:: on CUDA, hipcub:: on HIP
  • bnb_bfloat16 → __nv_bfloat16 on CUDA, hip_bfloat16 on HIP
  • bnb_stream_t → cudaStream_t / hipStream_t
  • BNB_MAX_OP → cub::Max() / hipcub::Max()
  • BNB_CHECK_RETURN() → unified error checking
  • bnb_blasLt*, bnb_sparse* → cuBLAS/hipBLAS and cuSPARSE/hipSPARSE
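
To make the mapping concrete, here is a minimal sketch of what the host-safe header could look like using the names above. The definitions are illustrative assumptions, not the code in csrc/examples/:

```cpp
// compat.cuh (sketch) -- host-safe: no CUB or other device-only headers, so it can
// be included from .cpp files compiled by gcc/g++. Definitions are illustrative.
#pragma once
#include <cstdio>
#include <cstdlib>

#if defined(__HIP_PLATFORM_AMD__)
#define BNB_HIP 1
#include <hip/hip_bfloat16.h>
#include <hip/hip_runtime.h>
typedef hipStream_t bnb_stream_t;
typedef hip_bfloat16 bnb_bfloat16;
#define BNB_CHECK_RETURN(value)                                          \
    {                                                                    \
        hipError_t _status = (value);                                    \
        if (_status != hipSuccess) {                                     \
            fprintf(stderr, "Error %s at %s:%d\n",                       \
                    hipGetErrorString(_status), __FILE__, __LINE__);     \
            exit(1);                                                     \
        }                                                                \
    }
#else
#define BNB_HIP 0
#include <cuda_bf16.h>
#include <cuda_runtime.h>
typedef cudaStream_t bnb_stream_t;
typedef __nv_bfloat16 bnb_bfloat16;
#define BNB_CHECK_RETURN(value)                                          \
    {                                                                    \
        cudaError_t _status = (value);                                   \
        if (_status != cudaSuccess) {                                    \
            fprintf(stderr, "Error %s at %s:%d\n",                       \
                    cudaGetErrorString(_status), __FILE__, __LINE__);    \
            exit(1);                                                     \
        }                                                                \
    }
#endif

// Backward-compatibility shim (see review note below); deprecated after migration.
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)
```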

Kernel code uses these abstractions and compiles unmodified on both platforms. The <<<grid, block>>> launch syntax works natively on HIP, so no hipLaunchKernelGGL wrappers are needed.
For HIP builds, CMake simply sets LANGUAGE HIP on the .cu files.
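
As a hypothetical illustration (the kernel and wrapper below are not from the RFC), the same source compiles under nvcc and under hipcc once CMake marks the file as LANGUAGE HIP:

```cpp
#include "compat.cuh" // for bnb_stream_t

// Hypothetical kernel: scales a buffer. Nothing here is CUDA- or HIP-specific.
__global__ void bnb_scale_kernel(float* out, const float* in, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = in[i] * s;
}

// Host wrapper: the triple-chevron launch is accepted natively by hipcc,
// so the unified .cu file needs no hipLaunchKernelGGL wrapper.
void scale(float* out, const float* in, float s, int n, bnb_stream_t stream) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    bnb_scale_kernel<<<blocks, threads, 0, stream>>>(out, in, s, n);
}
```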

#if BNB_HIP guards are only needed for genuinely divergent code (~10% of changes):

  • atomicMax (CUDA needs a CAS loop, HIP has a native overload; sketched after this list)
  • Context class (cuBLAS vs rocBLAS handle creation)
  • igemmlt (hipBLAS requires explicit heuristic algo selection)
  • Warp-size-dependent kernels (unified via BNB_WARP_SIZE compile-time constants)
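
For illustration, the atomicMax case could be wrapped roughly as follows. The helper name is hypothetical; the CAS loop is the usual CUDA idiom for a float atomic max:

```cpp
#include "compat.cuh" // for BNB_HIP

// Only genuinely divergent code needs a guard. ROCm ships a native float
// overload of atomicMax; CUDA emulates it with a compare-and-swap loop.
__device__ inline float bnb_atomicMaxFloat(float* addr, float val) {
#if BNB_HIP
    return atomicMax(addr, val);
#else
    int old = __float_as_int(*addr);
    int assumed;
    do {
        assumed = old;
        old = atomicCAS(reinterpret_cast<int*>(addr), assumed,
                        __float_as_int(fmaxf(val, __int_as_float(assumed))));
    } while (assumed != old);
    return __int_as_float(old);
#endif
}
```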

The split into two headers is necessary because .cpp files (like pythonInterface.cpp) are compiled by the host compiler (gcc/g++), which cannot parse CUB/device headers. Only .cu files
compiled by nvcc/hipcc include compat_device.cuh.
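
A minimal sketch of the device-only header under the same assumptions (the bnb_cub alias and BNB_MAX_OP names come from this RFC; the includes are illustrative):

```cpp
// compat_device.cuh (sketch) -- include only from .cu files compiled by nvcc/hipcc.
#pragma once
#include "compat.cuh"

#if BNB_HIP
#include <hipcub/hipcub.hpp>
namespace bnb_cub = hipcub; // bnb_cub::BlockLoad, bnb_cub::BlockReduce, ...
#define BNB_MAX_OP hipcub::Max()
#else
#include <cub/cub.cuh>
namespace bnb_cub = cub;
#define BNB_MAX_OP cub::Max()
#endif
```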

Example files in csrc/examples/

| File                        | Description                                                  |
|-----------------------------|--------------------------------------------------------------|
| compat.cuh                  | Host-safe portability header                                 |
| compat_device.cuh           | Device-only portability header (CUB, reduction ops, MMA)     |
| common_unified.cuh          | Merged common.cuh + common_hip.cuh                           |
| kernels_unified.cu          | Merged kernels.cu + kernels.hip                              |
| ops_unified.cuh             | Merged host API declarations + unified Context classes       |
| ops_unified.cu              | Merged host wrappers (incl. #if BNB_HIP for divergent APIs)  |
| pythonInterface_unified.cpp | Updated to use unified headers                               |
| CMakeLists_unified.txt      | Updated build system (single GPU_FILES list)                 |

End state after full migration

  • Delete 5 files: common_hip.cuh, kernels.hip, kernels_hip.cuh, ops.hip, ops_hip.cuh
  • Keep 5 files: common.cuh, kernels.cu, kernels.cuh, ops.cu, ops.cuh (now unified)
  • Add 2 files: compat.cuh, compat_device.cuh
  • Net result: 10 files → 7 files. The net reduction will be ~3300 lines.

@Abdennacer-Badaoui marked this pull request as draft February 18, 2026 13:53
@github-actions

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@TimDettmers (Collaborator) left a comment

PR Review: [Discussion] Unify CUDA and HIP kernel sources via compat.cuh portability layer

Classification: Refactoring / RFC (discussion-only, not intended to merge as-is)
Author: @Abdennacer-Badaoui (known contributor — authored the blocksize-32/64 kernels in #1854/#1856 that this unifies)
Risk level: Low (all files are additions in csrc/examples/, no existing code is modified)


Summary

This PR proposes a design for merging the duplicated CUDA and HIP kernel source files into a unified codebase using two new portability headers: compat.cuh (host-safe) and compat_device.cuh (device-only). The current codebase maintains near-identical copies of 5 pairs of files (~6500 LOC of duplication). The proposed approach would eliminate 5 files and ~3300 lines of duplication while introducing 2 new portability headers.

The 8 example files demonstrate the full approach. This is a well-structured RFC that shows rather than tells.


CI Status

  • Lint: FAIL (expected — clang-format likely hasn't been run on the new files)
  • build-wheels: FAIL (unrelated — dependency on lint)
  • All CUDA/HIP/CPU build & test jobs: PASS (these don't compile csrc/examples/)

The lint failure is expected for an RFC and is not a concern at this stage.


Design Assessment

The two-header split (compat.cuh for host-safe code, compat_device.cuh for device-only CUB/MMA) is a clean design. The rationale is solid: .cpp files compiled by gcc/g++ cannot parse CUDA device headers, so the split is necessary.

Strengths:

  1. Namespace aliasing for CUB/hipCUB (namespace bnb_cub = cub/hipcub) eliminates ~90% of the mechanical cub:: vs hipcub:: differences with a single line. Elegant.

  2. Compile-time BNB_WARP_SIZE in common_unified.cuh correctly handles the GFX9 (CDNA) 64-wide warps vs RDNA/CUDA 32-wide warps. The #ifdef __GFX9__ guard is correct for current ROCm architectures; a sketch of this constant follows this list.

  3. kQuantizeBlockwiseSmall successfully unifies kQuantizeBlockwise32 (CUDA) and kQuantizeBlockwise64 (HIP) by parameterizing on BNB_WARP_SIZE. The kernel logic is structurally identical to both originals — I verified the codebook values, reduction ops, quantization packing, and store patterns match.

  4. #if BNB_HIP guards are used sparingly and only where genuinely needed:

    • atomicMax (CUDA CAS loop vs HIP native)
    • Context class (cuBLAS vs rocBLAS handle creation)
    • gemmex/strided_gemmex (different BLAS APIs)
    • igemmlt (hipBLAS requires explicit heuristic algo selection)
    • blocksize==64 dispatch path in ops_unified.cu (only HIP with 64-wide warps needs the small-block kernel for blocksize=64)
  5. CMakeLists change is minimal and correct: single GPU_FILES list replaces separate CUDA_FILES/HIP_FILES, with set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) for HIP builds. The <<<>>> launch syntax works natively on HIP, so no hipLaunchKernelGGL wrappers are needed.
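
Sketch of the compile-time warp-size constant from item 2, following the review's description of the __GFX9__ guard (the exact spelling in common_unified.cuh may differ):

```cpp
// Compile-time warp size: 64-lane wavefronts on GFX9/CDNA, 32 lanes on RDNA and CUDA.
#if defined(__HIP_PLATFORM_AMD__) && defined(__GFX9__)
#define BNB_WARP_SIZE 64
#else
#define BNB_WARP_SIZE 32
#endif
```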

Technical concerns (for discussion):

  1. BNB_WARP_SIZE and blocksize=64 dispatch: In ops_unified.cu lines 50-61, the blocksize==64 path has a #if BNB_HIP guard to dispatch to kQuantizeBlockwiseSmall for 4-bit types on HIP. However, BNB_WARP_SIZE is a device-side macro (__GFX9__ is only defined in device code), while this dispatch decision is made in host code. How will the host-side code know whether to use the warp-64 path? The current approach uses #if BNB_HIP as a proxy, which is correct if the library is compiled separately for each target architecture, but could be wrong if a single HIP binary targets both CDNA (warp64) and RDNA (warp32) architectures simultaneously. This probably needs a runtime check (one possibility is sketched after this list), separate compilation for each arch, or a comment explaining the assumption.

  2. kQuantizeBlockwiseSmall name: The kernel is called "Small" but on HIP with warp=64, it handles blocksize=64 (not small at all). Consider kQuantizeBlockwiseWarp or similar to reflect that it processes warp-sized blocks. Minor naming nit.

  3. compat.cuh includes rocblas/rocblas.h and hipblas/hipblas.h unconditionally on HIP: These are heavyweight headers. If compat.cuh is meant to be "host-safe and lightweight," consider whether these BLAS includes belong here or in a separate BLAS compat header. Currently the Context class in ops_unified.cuh needs them, but other files including compat.cuh may not.

  4. BNB_BLASLT_PTR_MODE_ALPHA_VEC asymmetry: On CUDA this maps to CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO, on HIP to HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST. The BETA_ZERO vs BETA_HOST difference is notable — is this an intentional difference in how the two backends handle beta, or should it be BETA_ZERO on both? This discrepancy exists in the current code, so it's not introduced by this PR, but the unification is a good opportunity to document why.

  5. Missing bnb_blasLtPrefCreate/bnb_blasLtPrefSetAttr/bnb_blasLtAlgoGetHeuristic macros for CUDA: These are defined for HIP in compat.cuh but not for CUDA, because CUDA doesn't need the heuristic path. However, they're used inside a #if BNB_HIP block in ops_unified.cu, so there's no build failure — but it means the compat header is incomplete if someone tried to use these macros on CUDA. Add a comment or #ifdef guard noting these are HIP-only.

  6. CUDA_CHECK_RETURN backward compat macro: Good that compat.cuh defines #define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value) for migration purposes. This should be documented as deprecated and removed after the full migration.
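
Regarding concern 1, one possible resolution is a host-side runtime query of the device's warp size instead of the compile-time proxy; a hypothetical sketch, not part of the RFC:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical host-side helper: returns 64 on CDNA and 32 on RDNA, so the
// blocksize==64 dispatch can be decided per device rather than per build.
inline int bnb_device_warp_size(int device_id) {
    int warp_size = 32; // conservative default if the query fails
    (void)hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device_id);
    return warp_size;
}

// e.g. in the blocksize == 64 path:
//   if (bnb_device_warp_size(dev) == 64) { /* launch kQuantizeBlockwiseSmall */ }
```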


Security Review

  • No network access, command execution, or dynamic code execution introduced
  • No new dependencies added
  • No changes to pyproject.toml, CI workflows, or agent configuration files
  • No invisible Unicode characters detected in any file
  • Codebook values (FP4 and NF4 lookup tables) are byte-identical to the existing kernels.cu
  • CMakeLists changes are limited to file list unification — no new execute_process, FetchContent, or custom commands
  • Build flags unchanged

No security concerns.


Numerical Correctness

All quantization/dequantization kernel code is mechanically equivalent to the existing CUDA and HIP kernels. Specifically verified:

  • fp4_dequantization_lut and nf4_dequantization_lut values are identical
  • dQuantizeFP4, dQuantizeNF4, dDequantizeFP4Tree, dDequantizeNF4 logic is identical
  • atomicMax CAS loop is correctly guarded with #if !BNB_HIP
  • kQuantizeBlockwise template uses bnb_cub:: and BNB_MAX_OP as 1:1 replacements
  • kQuantizeBlockwiseSmall logic matches both kQuantizeBlockwise32 (CUDA) and kQuantizeBlockwise64 (HIP)
  • igemmlt preserves the HIP heuristic path and CUDA direct path

No numerical correctness concerns.


Downstream Impact

None. This PR adds files to csrc/examples/ — it does not modify any compiled source, public API, or serialization format. No downstream impact.


Cross-PR Conflicts

PR #1858 (k-bit blockwise quantization kernels) adds new CUDA kernels. If this RFC proceeds to full migration, the new kernels from #1858 would need to be written using the compat.cuh abstractions rather than raw CUDA APIs. Worth noting for sequencing.


Verdict: APPROVE (as RFC)

This is a well-designed RFC. The portability layer approach is sound, the #if BNB_HIP guards are minimal and limited to genuinely divergent code, and the unified kernel code is a faithful merge of the existing CUDA and HIP sources. The concerns listed above are discussion points for the design, not blockers.

For the full migration, I'd recommend:

  1. Resolve the warp-size host/device detection question (concern #1 above)
  2. Add compilation tests that verify the unified files build correctly for both CUDA and HIP
  3. Run the full test suite on both CUDA and ROCm hardware to verify numerical equivalence
  4. Sequence this after or coordinate with #1858 to avoid rework

@Abdennacer-Badaoui added the RFC (request for comments on proposed library improvements) label Feb 20, 2026
@matthewdouglas (Member) commented

@Abdennacer-Badaoui Thanks! This is essentially what I was expecting we could do. I think this is a good way forward. Most of the review comments above make sense as well!
