From 4858bdb5eba95609a90f2224f5ed73ba849b1432 Mon Sep 17 00:00:00 2001 From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:19:25 -0400 Subject: [PATCH] Incorporate vllm 0.5.5 kernels --- CMakeLists.txt | 322 +++ cmake/utils.cmake | 367 +++ csrc/activation_kernels.cu | 162 ++ .../attention/attention_dtypes.h | 3 +- .../attention/attention_generic.cuh | 19 +- csrc/attention/attention_kernels.cu | 1002 +++++++ .../attention/attention_utils.cuh | 13 +- .../attention/dtype_bfloat16.cuh | 82 +- .../attention/dtype_float16.cuh | 92 +- .../attention/dtype_float32.cuh | 88 +- csrc/attention/dtype_fp8.cuh | 41 + csrc/cache.h | 33 + csrc/cache_kernels.cu | 405 +++ csrc/core/registration.h | 22 + csrc/core/scalar_type.hpp | 547 ++++ csrc/core/torch_bindings.cpp | 16 + csrc/cpu/activation.cpp | 163 ++ csrc/cpu/attention.cpp | 758 ++++++ csrc/cpu/cache.cpp | 138 + csrc/cpu/cpu_types.hpp | 15 + csrc/cpu/cpu_types_vsx.hpp | 491 ++++ csrc/cpu/cpu_types_x86.hpp | 515 ++++ csrc/cpu/layernorm.cpp | 117 + csrc/cpu/pos_encoding.cpp | 199 ++ csrc/cpu/torch_bindings.cpp | 117 + csrc/cpu/utils.cpp | 65 + csrc/{paged_attention => }/cuda_compat.h | 26 +- csrc/cuda_utils.h | 15 + csrc/cuda_utils_kernels.cu | 29 + csrc/custom_all_reduce.cu | 153 ++ csrc/custom_all_reduce.cuh | 482 ++++ csrc/custom_all_reduce_test.cu | 316 +++ csrc/cutlass_extensions/cute_utils.cuh | 68 + csrc/cutlass_extensions/torch_utils.hpp | 154 ++ .../vllm_collective_builder.cuh | 43 + csrc/cutlass_extensions/vllm_custom_types.cuh | 50 + .../vllm_cutlass_library_extension.py | 49 + .../vllm_numeric_conversion.cuh | 795 ++++++ csrc/dispatch_utils.h | 35 + csrc/layernorm_kernels.cu | 357 +++ csrc/moe/moe_ops.h | 7 + csrc/moe/topk_softmax_kernels.cu | 506 ++++ csrc/moe/torch_bindings.cpp | 12 + csrc/moe_align_block_size_kernels.cu | 134 + csrc/ops.h | 218 ++ csrc/paged_attention/README.md | 3 - .../attention/attention_kernels.cu | 882 ------- csrc/paged_attention/cache.h | 18 - csrc/paged_attention/cache_kernels.cu | 190 -- csrc/paged_attention/cuda_utils.h | 7 - csrc/paged_attention/cuda_utils_kernels.cu | 17 - csrc/paged_attention/dispatch_utils.h | 16 - csrc/paged_attention/ops.h | 32 - csrc/paged_attention/pybind.cpp | 38 - csrc/paged_attention/reduction_utils.cuh | 51 - csrc/pos_encoding_kernels.cu | 203 ++ csrc/prepare_inputs/advance_step.cu | 131 + csrc/prepare_inputs/advance_step.cuh | 19 + csrc/quantization/aqlm/gemm_kernels.cu | 597 +++++ csrc/quantization/awq/dequantize.cuh | 102 + csrc/quantization/awq/gemm_kernels.cu | 526 ++++ .../compressed_tensors/int8_quant_kernels.cu | 125 + csrc/quantization/cutlass_w8a8/Epilogues.md | 147 ++ .../broadcast_load_epilogue_c2x.hpp | 496 ++++ .../broadcast_load_epilogue_c3x.hpp | 447 ++++ csrc/quantization/cutlass_w8a8/common.hpp | 27 + .../cutlass_w8a8/scaled_mm_c2x.cu | 198 ++ .../cutlass_w8a8/scaled_mm_c2x.cuh | 521 ++++ .../scaled_mm_c2x_sm75_dispatch.cuh | 123 + .../scaled_mm_c2x_sm80_dispatch.cuh | 139 + .../scaled_mm_c2x_sm89_fp8_dispatch.cuh | 368 +++ .../scaled_mm_c2x_sm89_int8_dispatch.cuh | 353 +++ .../cutlass_w8a8/scaled_mm_c3x.cu | 751 ++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 198 ++ csrc/quantization/fp8/amd/hip_float8.h | 137 + csrc/quantization/fp8/amd/hip_float8_impl.h | 316 +++ csrc/quantization/fp8/amd/quant_utils.cuh | 577 +++++ csrc/quantization/fp8/common.cu | 316 +++ csrc/quantization/fp8/fp8_marlin.cu | 1305 ++++++++++ csrc/quantization/fp8/nvidia/quant_utils.cuh | 573 ++++ csrc/quantization/gguf/dequantize.cuh | 531 ++++ 
csrc/quantization/gguf/ggml-common.h | 969 +++++++ csrc/quantization/gguf/gguf_kernel.cu | 242 ++ csrc/quantization/gguf/mmq.cuh | 600 +++++ csrc/quantization/gguf/mmvq.cuh | 182 ++ csrc/quantization/gguf/vecdotq.cuh | 1745 +++++++++++++ csrc/quantization/gptq/compat.cuh | 64 + csrc/quantization/gptq/matrix_view.cuh | 295 +++ csrc/quantization/gptq/q_gemm.cu | 1856 +++++++++++++ csrc/quantization/gptq/qdq_2.cuh | 76 + csrc/quantization/gptq/qdq_3.cuh | 149 ++ csrc/quantization/gptq/qdq_4.cuh | 126 + csrc/quantization/gptq/qdq_8.cuh | 30 + csrc/quantization/gptq/qdq_util.cuh | 56 + .../gptq_marlin/awq_marlin_repack.cu | 269 ++ csrc/quantization/gptq_marlin/gptq_marlin.cu | 2299 +++++++++++++++++ .../gptq_marlin/gptq_marlin_repack.cu | 344 +++ csrc/quantization/gptq_marlin/marlin.cuh | 87 + .../gptq_marlin/marlin_dtypes.cuh | 79 + csrc/quantization/machete/Readme.md | 45 + csrc/quantization/machete/generate.py | 446 ++++ .../machete/generated/machete_mm_bf16u4.cu | 70 + .../generated/machete_mm_bf16u4_impl_part0.cu | 78 + .../generated/machete_mm_bf16u4_impl_part1.cu | 78 + .../machete/generated/machete_mm_bf16u4b8.cu | 70 + .../machete_mm_bf16u4b8_impl_part0.cu | 78 + .../machete_mm_bf16u4b8_impl_part1.cu | 78 + .../machete/generated/machete_mm_bf16u8.cu | 70 + .../generated/machete_mm_bf16u8_impl_part0.cu | 78 + .../generated/machete_mm_bf16u8_impl_part1.cu | 78 + .../generated/machete_mm_bf16u8b128.cu | 70 + .../machete_mm_bf16u8b128_impl_part0.cu | 78 + .../machete_mm_bf16u8b128_impl_part1.cu | 78 + .../machete/generated/machete_mm_f16u4.cu | 70 + .../generated/machete_mm_f16u4_impl_part0.cu | 78 + .../generated/machete_mm_f16u4_impl_part1.cu | 78 + .../machete/generated/machete_mm_f16u4b8.cu | 70 + .../machete_mm_f16u4b8_impl_part0.cu | 78 + .../machete_mm_f16u4b8_impl_part1.cu | 78 + .../machete/generated/machete_mm_f16u8.cu | 70 + .../generated/machete_mm_f16u8_impl_part0.cu | 78 + .../generated/machete_mm_f16u8_impl_part1.cu | 78 + .../machete/generated/machete_mm_f16u8b128.cu | 70 + .../machete_mm_f16u8b128_impl_part0.cu | 78 + .../machete_mm_f16u8b128_impl_part1.cu | 78 + .../generated/machete_prepack_bf16u4.cu | 25 + .../generated/machete_prepack_bf16u4b8.cu | 25 + .../generated/machete_prepack_bf16u8.cu | 25 + .../generated/machete_prepack_bf16u8b128.cu | 25 + .../generated/machete_prepack_f16u4.cu | 25 + .../generated/machete_prepack_f16u4b8.cu | 25 + .../generated/machete_prepack_f16u8.cu | 25 + .../generated/machete_prepack_f16u8b128.cu | 25 + .../machete/machete_collective_builder.cuh | 33 + .../machete/machete_interleaving_utils.cuh | 35 + .../quantization/machete/machete_mainloop.cuh | 1473 +++++++++++ .../machete/machete_mm_kernel.cuh | 237 ++ .../machete/machete_mm_launcher.cuh | 95 + .../machete/machete_prepack_kernel.cuh | 62 + .../machete/machete_prepack_launcher.cuh | 71 + .../machete/machete_prepacked_layout.cuh | 220 ++ csrc/quantization/machete/machete_pytorch.cu | 91 + csrc/quantization/marlin/dense/LICENSE | 209 ++ csrc/quantization/marlin/dense/common/base.h | 32 + csrc/quantization/marlin/dense/common/mem.h | 89 + .../marlin/dense/marlin_cuda_kernel.cu | 1068 ++++++++ .../marlin/qqq/marlin_qqq_gemm_kernel.cu | 1243 +++++++++ csrc/quantization/marlin/sparse/LICENSE | 203 ++ csrc/quantization/marlin/sparse/common/base.h | 51 + csrc/quantization/marlin/sparse/common/mem.h | 136 + csrc/quantization/marlin/sparse/common/mma.h | 191 ++ .../marlin/sparse/marlin_24_cuda_kernel.cu | 1136 ++++++++ .../squeezellm/quant_cuda_kernel.cu | 216 ++ csrc/torch_bindings.cpp | 354 
+++ fms_extras/_core_ext.py | 271 ++ fms_extras/_custom_ops.py | 618 +++++ fms_extras/utils/cache/paged.py | 281 +- pyproject.toml | 2 +- requirements-build.txt | 12 +- requirements.txt | 2 +- setup.py | 409 ++- 161 files changed, 38631 insertions(+), 1897 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 cmake/utils.cmake create mode 100644 csrc/activation_kernels.cu rename csrc/{paged_attention => }/attention/attention_dtypes.h (65%) rename csrc/{paged_attention => }/attention/attention_generic.cuh (77%) create mode 100644 csrc/attention/attention_kernels.cu rename csrc/{paged_attention => }/attention/attention_utils.cuh (78%) rename csrc/{paged_attention => }/attention/dtype_bfloat16.cuh (87%) rename csrc/{paged_attention => }/attention/dtype_float16.cuh (88%) rename csrc/{paged_attention => }/attention/dtype_float32.cuh (76%) create mode 100644 csrc/attention/dtype_fp8.cuh create mode 100644 csrc/cache.h create mode 100644 csrc/cache_kernels.cu create mode 100644 csrc/core/registration.h create mode 100644 csrc/core/scalar_type.hpp create mode 100644 csrc/core/torch_bindings.cpp create mode 100644 csrc/cpu/activation.cpp create mode 100644 csrc/cpu/attention.cpp create mode 100644 csrc/cpu/cache.cpp create mode 100644 csrc/cpu/cpu_types.hpp create mode 100644 csrc/cpu/cpu_types_vsx.hpp create mode 100644 csrc/cpu/cpu_types_x86.hpp create mode 100644 csrc/cpu/layernorm.cpp create mode 100644 csrc/cpu/pos_encoding.cpp create mode 100644 csrc/cpu/torch_bindings.cpp create mode 100644 csrc/cpu/utils.cpp rename csrc/{paged_attention => }/cuda_compat.h (52%) create mode 100644 csrc/cuda_utils.h create mode 100644 csrc/cuda_utils_kernels.cu create mode 100644 csrc/custom_all_reduce.cu create mode 100644 csrc/custom_all_reduce.cuh create mode 100644 csrc/custom_all_reduce_test.cu create mode 100644 csrc/cutlass_extensions/cute_utils.cuh create mode 100644 csrc/cutlass_extensions/torch_utils.hpp create mode 100644 csrc/cutlass_extensions/vllm_collective_builder.cuh create mode 100644 csrc/cutlass_extensions/vllm_custom_types.cuh create mode 100644 csrc/cutlass_extensions/vllm_cutlass_library_extension.py create mode 100644 csrc/cutlass_extensions/vllm_numeric_conversion.cuh create mode 100644 csrc/dispatch_utils.h create mode 100644 csrc/layernorm_kernels.cu create mode 100644 csrc/moe/moe_ops.h create mode 100644 csrc/moe/topk_softmax_kernels.cu create mode 100644 csrc/moe/torch_bindings.cpp create mode 100644 csrc/moe_align_block_size_kernels.cu create mode 100644 csrc/ops.h delete mode 100644 csrc/paged_attention/README.md delete mode 100644 csrc/paged_attention/attention/attention_kernels.cu delete mode 100644 csrc/paged_attention/cache.h delete mode 100644 csrc/paged_attention/cache_kernels.cu delete mode 100644 csrc/paged_attention/cuda_utils.h delete mode 100644 csrc/paged_attention/cuda_utils_kernels.cu delete mode 100644 csrc/paged_attention/dispatch_utils.h delete mode 100644 csrc/paged_attention/ops.h delete mode 100644 csrc/paged_attention/pybind.cpp delete mode 100644 csrc/paged_attention/reduction_utils.cuh create mode 100644 csrc/pos_encoding_kernels.cu create mode 100644 csrc/prepare_inputs/advance_step.cu create mode 100644 csrc/prepare_inputs/advance_step.cuh create mode 100644 csrc/quantization/aqlm/gemm_kernels.cu create mode 100644 csrc/quantization/awq/dequantize.cuh create mode 100644 csrc/quantization/awq/gemm_kernels.cu create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu create mode 100644 
csrc/quantization/cutlass_w8a8/Epilogues.md create mode 100644 csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp create mode 100644 csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp create mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu create mode 100644 csrc/quantization/fp8/amd/hip_float8.h create mode 100644 csrc/quantization/fp8/amd/hip_float8_impl.h create mode 100644 csrc/quantization/fp8/amd/quant_utils.cuh create mode 100644 csrc/quantization/fp8/common.cu create mode 100644 csrc/quantization/fp8/fp8_marlin.cu create mode 100644 csrc/quantization/fp8/nvidia/quant_utils.cuh create mode 100644 csrc/quantization/gguf/dequantize.cuh create mode 100644 csrc/quantization/gguf/ggml-common.h create mode 100644 csrc/quantization/gguf/gguf_kernel.cu create mode 100644 csrc/quantization/gguf/mmq.cuh create mode 100644 csrc/quantization/gguf/mmvq.cuh create mode 100644 csrc/quantization/gguf/vecdotq.cuh create mode 100644 csrc/quantization/gptq/compat.cuh create mode 100644 csrc/quantization/gptq/matrix_view.cuh create mode 100644 csrc/quantization/gptq/q_gemm.cu create mode 100644 csrc/quantization/gptq/qdq_2.cuh create mode 100644 csrc/quantization/gptq/qdq_3.cuh create mode 100644 csrc/quantization/gptq/qdq_4.cuh create mode 100644 csrc/quantization/gptq/qdq_8.cuh create mode 100644 csrc/quantization/gptq/qdq_util.cuh create mode 100644 csrc/quantization/gptq_marlin/awq_marlin_repack.cu create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin.cu create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin_repack.cu create mode 100644 csrc/quantization/gptq_marlin/marlin.cuh create mode 100644 csrc/quantization/gptq_marlin/marlin_dtypes.cuh create mode 100644 csrc/quantization/machete/Readme.md create mode 100644 csrc/quantization/machete/generate.py create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4b8.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4b8_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u4b8_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8b128.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8b128_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_bf16u8b128_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u4.cu create mode 100644 
csrc/quantization/machete/generated/machete_mm_f16u4_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u4_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u4b8.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u4b8_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u4b8_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8b128.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8b128_impl_part0.cu create mode 100644 csrc/quantization/machete/generated/machete_mm_f16u8b128_impl_part1.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_bf16u4.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_bf16u4b8.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_bf16u8.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_bf16u8b128.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_f16u4.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_f16u4b8.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_f16u8.cu create mode 100644 csrc/quantization/machete/generated/machete_prepack_f16u8b128.cu create mode 100644 csrc/quantization/machete/machete_collective_builder.cuh create mode 100644 csrc/quantization/machete/machete_interleaving_utils.cuh create mode 100644 csrc/quantization/machete/machete_mainloop.cuh create mode 100644 csrc/quantization/machete/machete_mm_kernel.cuh create mode 100644 csrc/quantization/machete/machete_mm_launcher.cuh create mode 100644 csrc/quantization/machete/machete_prepack_kernel.cuh create mode 100644 csrc/quantization/machete/machete_prepack_launcher.cuh create mode 100644 csrc/quantization/machete/machete_prepacked_layout.cuh create mode 100644 csrc/quantization/machete/machete_pytorch.cu create mode 100644 csrc/quantization/marlin/dense/LICENSE create mode 100644 csrc/quantization/marlin/dense/common/base.h create mode 100644 csrc/quantization/marlin/dense/common/mem.h create mode 100644 csrc/quantization/marlin/dense/marlin_cuda_kernel.cu create mode 100644 csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu create mode 100644 csrc/quantization/marlin/sparse/LICENSE create mode 100644 csrc/quantization/marlin/sparse/common/base.h create mode 100644 csrc/quantization/marlin/sparse/common/mem.h create mode 100644 csrc/quantization/marlin/sparse/common/mma.h create mode 100644 csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu create mode 100644 csrc/quantization/squeezellm/quant_cuda_kernel.cu create mode 100644 csrc/torch_bindings.cpp create mode 100644 fms_extras/_core_ext.py create mode 100644 fms_extras/_custom_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..35c3a6e --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,322 @@ +cmake_minimum_required(VERSION 3.26) + +project(vllm_extensions LANGUAGES CXX) + +# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... 
(used by setup.py) +set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") + +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") +message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +# Suppress potential warnings about unused manually-specified variables +set(ignoreMe "${VLLM_PYTHON_PATH}") + +# +# Supported python versions. These versions will be searched in order, the +# first match will be selected. These should be kept in sync with setup.py. +# +set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") + +# Supported NVIDIA architectures. +set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") + +# Supported AMD GPU architectures. +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") + +# +# Supported/expected torch versions for CUDA/ROCm. +# +# Currently, having an incorrect pytorch version results in a warning +# rather than an error. +# +# Note: the CUDA torch version is derived from pyproject.toml and various +# requirements.txt files and should be kept consistent. The ROCm torch +# versions are derived from Dockerfile.rocm +# +set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") + +# +# Try to find python package with an executable that exactly matches +# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions. +# +if (VLLM_PYTHON_EXECUTABLE) + find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}") +else() + message(FATAL_ERROR + "Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version" + " before running cmake configure.") +endif() + +# +# Update cmake's `CMAKE_PREFIX_PATH` with torch location. +# +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +# Ensure the 'nvcc' command is in the PATH +find_program(NVCC_EXECUTABLE nvcc) +if (CUDA_FOUND AND NOT NVCC_EXECUTABLE) + message(FATAL_ERROR "nvcc not found") +endif() + +# +# Import torch cmake configuration. +# Torch also imports CUDA (and partially HIP) languages with some customizations, +# so there is no need to do this explicitly with check_language/enable_language, +# etc. +# +find_package(Torch REQUIRED) + +# +# Add the `default` target which detects which extensions should be +# built based on platform/architecture. This is the same logic that +# setup.py uses to select which extensions should be built and should +# be kept in sync. +# +# The `default` target makes direct use of cmake easier since knowledge +# of which extensions are supported has been factored in, e.g. +# +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +# cmake --build . --target default +# +add_custom_target(default) +message(STATUS "Enabling core extension.") + +# Define _core_C extension +# built for (almost) every target platform, (excludes TPU and Neuron) + +set(VLLM_EXT_SRC + "csrc/core/torch_bindings.cpp") + +set(DESTINATION_NAME + "fms_extras") + +define_gpu_extension_target( + _core_C + DESTINATION ${DESTINATION_NAME} + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI) + +add_dependencies(default _core_C) + +# +# Forward the non-CUDA device extensions to external CMake scripts. 
+# +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + NOT VLLM_TARGET_DEVICE STREQUAL "rocm") + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() + return() + endif() + return() +endif() + +# +# Set up GPU language and check the torch version and warn if it isn't +# what is expected. +# +if (NOT HIP_FOUND AND CUDA_FOUND) + set(VLLM_GPU_LANG "CUDA") + + if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} " + "expected for CUDA build, saw ${Torch_VERSION} instead.") + endif() +elseif(HIP_FOUND) + set(VLLM_GPU_LANG "HIP") + + # Importing torch recognizes and sets up some HIP/ROCm configuration but does + # not let cmake recognize .hip files. In order to get cmake to understand the + # .hip extension automatically, HIP must be enabled explicitly. + enable_language(HIP) + + # ROCm 5.X and 6.X + if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND + NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) + message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " + "expected for ROCm build, saw ${Torch_VERSION} instead.") + endif() +else() + message(FATAL_ERROR "Can't find CUDA or HIP installation.") +endif() + +# +# Override the GPU architectures detected by cmake/torch and filter them by +# the supported versions for the current language. +# The final set of arches is stored in `VLLM_GPU_ARCHES`. +# +override_gpu_arches(VLLM_GPU_ARCHES + ${VLLM_GPU_LANG} + "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") + +# +# Query torch for additional GPU compilation flags for the given +# `VLLM_GPU_LANG`. +# The final set of arches is stored in `VLLM_GPU_FLAGS`. +# +get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) + +# +# Set nvcc parallelism. 
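+# For example, configuring with -DNVCC_THREADS=8 lets each nvcc invocation run
+# its independent compilation steps (such as the per-architecture passes) on up
+# to 8 threads.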
+# +if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") +endif() + +# +# Define other extension targets +# + +# +# _C extension +# + +set(VLLM_EXT_SRC + "csrc/cache_kernels.cu" + "csrc/attention/attention_kernels.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/quantization/squeezellm/quant_cuda_kernel.cu" + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" + "csrc/quantization/fp8/common.cu" + "csrc/cuda_utils_kernels.cu" + "csrc/moe_align_block_size_kernels.cu" + "csrc/prepare_inputs/advance_step.cu" + "csrc/torch_bindings.cpp") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.1 + GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 + GIT_PROGRESS TRUE + ) + FetchContent_MakeAvailable(cutlass) + + list(APPEND VLLM_EXT_SRC + "csrc/quantization/aqlm/gemm_kernels.cu" + "csrc/quantization/awq/gemm_kernels.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/fp8/fp8_marlin.cu" + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + + # + # Machete kernels + + # The machete kernels only work on hopper and require CUDA 12.0 or later. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." 
+ " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + message(STATUS "Machete generation completed successfully.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}") + + set_source_files_properties( + ${MACHETE_GEN_SOURCES} + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + + # Add pytorch binding for machete (add on even CUDA < 12.0 so that we can + # raise an error if the user that this was built with an incompatible + # CUDA version) + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) +endif() + +define_gpu_extension_target( + _C + DESTINATION ${DESTINATION_NAME} + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + +# +# _moe_C extension +# + +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/topk_softmax_kernels.cu") + +define_gpu_extension_target( + _moe_C + DESTINATION ${DESTINATION_NAME} + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_MOE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + + + +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling C extension.") + add_dependencies(default _C) + + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) + +endif() + diff --git a/cmake/utils.cmake b/cmake/utils.cmake new file mode 100644 index 0000000..a40224c --- /dev/null +++ b/cmake/utils.cmake @@ -0,0 +1,367 @@ +# +# Attempt to find the python package that uses the same python executable as +# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`. +# +macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python_EXECUTABLE ${EXECUTABLE}) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() + set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}") + set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) + if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) + message(FATAL_ERROR + "Python version (${_VER}) is not one of the supported versions: " + "${_SUPPORTED_VERSIONS_LIST}.") + endif() + message(STATUS "Found python matching: ${EXECUTABLE}.") +endmacro() + +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${Python_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. 
+macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() + +# +# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set +# of CUDA source files. The names of the corresponding "hipified" sources are +# stored in `OUT_SRCS`. +# +function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) + # + # Split into C++ and non-C++ (i.e. CUDA) sources. + # + set(SRCS ${ORIG_SRCS}) + set(CXX_SRCS ${ORIG_SRCS}) + list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$") + list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$") + + # + # Generate ROCm/HIP source file names from CUDA file names. + # Since HIP files are generated code, they will appear in the build area + # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir. + # + set(HIP_SRCS) + foreach (SRC ${SRCS}) + string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC}) + string(REGEX REPLACE "cuda" "hip" SRC ${SRC}) + list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}") + endforeach() + + set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc) + add_custom_target( + hipify${NAME} + COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS} + DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS} + BYPRODUCTS ${HIP_SRCS} + COMMENT "Running hipify on ${NAME} extension source files.") + + # Swap out original extension sources with hipified sources. + list(APPEND HIP_SRCS ${CXX_SRCS}) + set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE) +endfunction() + +# +# Get additional GPU compiler flags from torch. +# +function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) + if (${GPU_LANG} STREQUAL "CUDA") + # + # Get common NVCC flags from torch. + # + run_python(GPU_FLAGS + "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) + list(APPEND GPU_FLAGS "-DENABLE_FP8") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) + list(REMOVE_ITEM GPU_FLAGS + "-D__CUDA_NO_HALF_OPERATORS__" + "-D__CUDA_NO_HALF_CONVERSIONS__" + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" + "-D__CUDA_NO_HALF2_OPERATORS__") + endif() + + elseif(${GPU_LANG} STREQUAL "HIP") + # + # Get common HIP/HIPCC flags from torch. + # + run_python(GPU_FLAGS + "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + list(APPEND GPU_FLAGS + "-DUSE_ROCM" + "-DENABLE_FP8" + "-U__HIP_NO_HALF_CONVERSIONS__" + "-U__HIP_NO_HALF_OPERATORS__" + "-fno-gpu-rdc") + + endif() + set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) +endfunction() + +# Macro for converting a `gencode` version number to a cmake version number. +macro(string_to_ver OUT_VER IN_STR) + string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) +endmacro() + +# +# Override the GPU architectures detected by cmake/torch and filter them by +# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in +# `GPU_ARCHES`. +# +# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. +# +macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) + set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN}) + message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}") + + if (${GPU_LANG} STREQUAL "HIP") + # + # `GPU_ARCHES` controls the `--offload-arch` flags. 
+ # + # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, + # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling + # "rocm_agent_enumerator" in "enable_language(HIP)" + # (in file Modules/CMakeDetermineHIPCompiler.cmake) + # + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) + else() + set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) + endif() + # + # Find the intersection of the supported + detected architectures to + # set the module architecture flags. + # + set(${GPU_ARCHES}) + foreach (_ARCH ${HIP_ARCHITECTURES}) + if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) + list(APPEND ${GPU_ARCHES} ${_ARCH}) + endif() + endforeach() + + if(NOT ${GPU_ARCHES}) + message(FATAL_ERROR + "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" + " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") + endif() + + elseif(${GPU_LANG} STREQUAL "CUDA") + # + # Setup/process CUDA arch flags. + # + # The torch cmake setup hardcodes the detected architecture flags in + # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it + # can't modified on a per-target basis. + # So, all the `-gencode` flags need to be extracted and removed from + # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. + # Since it's not possible to use `target_compiler_options` for adding target + # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property + # must be used instead. This requires repackaging the architecture flags + # into a format that cmake expects for `CUDA_ARCHITECTURES`. + # + # This is a bit fragile in that it depends on torch using `-gencode` as opposed + # to one of the other nvcc options to specify architectures. + # + # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override + # detected architectures. + # + message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") + + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # If this error is triggered, it might mean that torch has changed how it sets + # up nvcc architecture code generation flags. + if (NOT _CUDA_ARCH_FLAGS) + message(FATAL_ERROR + "Could not find any architecture related code generation flags in " + "CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})") + endif() + + message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") + message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}") + + # Initialize the architecture lists to empty. + set(${GPU_ARCHES}) + + # Process each `gencode` flag. + foreach(_ARCH ${_CUDA_ARCH_FLAGS}) + # For each flag, extract the version number and whether it refers to PTX + # or native code. + # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding + # for that match. + + string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) + if (_COMPUTE) + set(_COMPUTE ${CMAKE_MATCH_1}) + endif() + + string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH}) + if (_SM) + set(_SM ${CMAKE_MATCH_1}) + endif() + + string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH}) + if (_CODE) + set(_CODE ${CMAKE_MATCH_1}) + endif() + + # Make sure the virtual architecture can be matched. 
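+    # For example, a torch-provided "-gencode arch=compute_80,code=sm_80" leaves
+    # _COMPUTE=80 and _SM=80 (turned into "80-real" below), while
+    # "-gencode arch=compute_80,code=compute_80" leaves _CODE=80 ("80-virtual").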
+    if (NOT _COMPUTE)
+      message(FATAL_ERROR
+        "Could not determine virtual architecture from: ${_ARCH}.")
+    endif()
+
+    # One of sm_ or compute_ must exist.
+    if ((NOT _SM) AND (NOT _CODE))
+      message(FATAL_ERROR
+        "Could not determine a codegen architecture from: ${_ARCH}.")
+    endif()
+
+    if (_SM)
+      # The -real suffix tells CMake to generate only ELF code for the kernels.
+      # We want this, otherwise the PTX that is added by default would increase
+      # the binary size.
+      set(_VIRT "-real")
+      set(_CODE_ARCH ${_SM})
+    else()
+      # The -virtual suffix tells CMake to generate PTX code for the kernels.
+      set(_VIRT "-virtual")
+      set(_CODE_ARCH ${_CODE})
+    endif()
+
+    # Check if the current version is in the supported arch list.
+    string_to_ver(_CODE_VER ${_CODE_ARCH})
+    if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
+      message(STATUS "discarding unsupported CUDA arch ${_CODE_VER}.")
+      continue()
+    endif()
+
+    # Add it to the arch list.
+    list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
+  endforeach()
+  endif()
+  message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
+endmacro()
+
+#
+# Define a target named `GPU_MOD_NAME` for a single extension. The
+# arguments are:
+#
+# DESTINATION          - Module destination directory.
+# LANGUAGE             - The GPU language for this module, e.g. CUDA, HIP,
+#                        etc.
+# SOURCES              - List of source files relative to CMakeLists.txt
+#                        directory.
+#
+# Optional arguments:
+#
+# ARCHITECTURES        - A list of target GPU architectures in cmake
+#                        format.
+#                        Refer to the `CMAKE_CUDA_ARCHITECTURES` documentation
+#                        and `CMAKE_HIP_ARCHITECTURES` for more info.
+#                        ARCHITECTURES will use cmake's defaults if
+#                        not provided.
+# COMPILE_FLAGS        - Extra compiler flags passed to NVCC/hip.
+# INCLUDE_DIRECTORIES  - Extra include directories.
+# LIBRARIES            - Extra link libraries.
+# WITH_SOABI           - Generate library with python SOABI suffix name.
+# USE_SABI             - Use the Python stable ABI.
+#
+# Note: optimization level/debug info is set via cmake build type.
+#
+function (define_gpu_extension_target GPU_MOD_NAME)
+  cmake_parse_arguments(PARSE_ARGV 1
+    GPU
+    "WITH_SOABI"
+    "DESTINATION;LANGUAGE;USE_SABI"
+    "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
+
+  # Add hipify preprocessing step when building with HIP/ROCm.
+  if (GPU_LANGUAGE STREQUAL "HIP")
+    hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
+  endif()
+
+  if (GPU_WITH_SOABI)
+    set(GPU_WITH_SOABI WITH_SOABI)
+  else()
+    set(GPU_WITH_SOABI)
+  endif()
+
+  if (GPU_USE_SABI)
+    Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
+
+  if (GPU_LANGUAGE STREQUAL "HIP")
+    # Make this target dependent on the hipify preprocessor step.
+    add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
+  endif()
+
+  if (GPU_ARCHITECTURES)
+    set_target_properties(${GPU_MOD_NAME} PROPERTIES
+      ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
+  endif()
+
+  set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
+
+  target_compile_options(${GPU_MOD_NAME} PRIVATE
+    $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
+
+  target_compile_definitions(${GPU_MOD_NAME} PRIVATE
+    "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
+
+  target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
+    ${GPU_INCLUDE_DIRECTORIES})
+
+  target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
+    ${GPU_LIBRARIES})
+
+  # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
+  # dependencies that are not necessary and may not be installed.
+ if (GPU_LANGUAGE STREQUAL "CUDA") + target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} + ${CUDA_LIBRARIES}) + else() + target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) + endif() + + install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}) +endfunction() + diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu new file mode 100644 index 0000000..5ed1dc3 --- /dev/null +++ b/csrc/activation_kernels.cu @@ -0,0 +1,162 @@ +#include +#include +#include + +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +namespace vllm { + +// Activation and gating kernel template. +template +__global__ void act_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = ACT_FN(x) * y; + } +} + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T gelu_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 + const float f = (float)x; + constexpr float ALPHA = M_SQRT1_2; + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); +} + +template +__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'tanh' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 + const float f = (float)x; + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + float x_cube = f * f * f; + float inner = BETA * (f + KAPPA * x_cube); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); +} + +} // namespace vllm + +// Launch activation and gating kernel. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); +} + +namespace vllm { + +// Element-wise activation kernel template. 
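As a cross-check on the gated layout used by act_and_mul_kernel above (each token row holds the gate values x in its first d columns and the second operand y in the last d columns), here is a minimal CPU reference of the SiLU-and-multiply semantics. It is an illustrative sketch only; silu_and_mul_ref is a name invented for this example and is not part of the patch.

#include <cmath>
#include <cstddef>
#include <vector>

// out[t][i] = silu(input[t][i]) * input[t][d + i], matching the indexing in
// act_and_mul_kernel (input is [num_tokens, 2 * d], out is [num_tokens, d]).
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

static void silu_and_mul_ref(std::vector<float>& out,
                             const std::vector<float>& input,
                             std::size_t num_tokens, std::size_t d) {
  out.assign(num_tokens * d, 0.0f);
  for (std::size_t t = 0; t < num_tokens; ++t) {
    for (std::size_t i = 0; i < d; ++i) {
      const float x = input[t * 2 * d + i];      // gate half
      const float y = input[t * 2 * d + d + i];  // second half
      out[t * d + i] = silu(x) * y;
    }
  }
}

The real kernel runs one CUDA block per token and strides the inner loop over threadIdx.x, which is why the launch macro sets grid(num_tokens) and block(min(d, 1024)).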
+template +__global__ void activation_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); + out[token_idx * d + idx] = ACT_FN(x); + } +} + +} // namespace vllm + +// Launch element-wise activation kernel. +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +namespace vllm { + +template +__device__ __forceinline__ T gelu_new_kernel(const T& x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_fast_kernel(const T& x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_quick_kernel(const T& x) { + // x * sigmoid(1.702 * x) + return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); +} + +} // namespace vllm + +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); +} + +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); +} + +void gelu_quick(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); +} diff --git a/csrc/paged_attention/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h similarity index 65% rename from csrc/paged_attention/attention/attention_dtypes.h rename to csrc/attention/attention_dtypes.h index 05e59a9..64f8638 100644 --- a/csrc/paged_attention/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -3,4 +3,5 @@ #include "attention_generic.cuh" #include "dtype_float16.cuh" #include "dtype_float32.cuh" -#include "dtype_bfloat16.cuh" \ No newline at end of file +#include "dtype_bfloat16.cuh" +#include "dtype_fp8.cuh" diff --git a/csrc/paged_attention/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh similarity index 77% rename from csrc/paged_attention/attention/attention_generic.cuh rename to csrc/attention/attention_generic.cuh index d94440b..62409c0 100644 --- a/csrc/paged_attention/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. 
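The Vec and FloatVec primary templates declared just below are only anchors: the dtype headers pulled in through attention_dtypes.h specialize them per scalar type and vector width. For example, in dtype_float16.cuh a two-element half vector is represented as a packed uint32_t (Vec<uint16_t, 2>::Type) and its FP32 accumulator counterpart as float2, which is what lets the attention kernels load and multiply keys and queries several elements at a time.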
-template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu new file mode 100644 index 0000000..bcd1704 --- /dev/null +++ b/csrc/attention/attention_kernels.cu @@ -0,0 +1,1002 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. 
+__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. 
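To make the comment above concrete, the vector width follows mechanically from the block size and the element size. A small host-side sketch (using uint16_t as a 2-byte stand-in for half, and assuming WARP_SIZE == 32; on ROCm the warp size may be 64) works out a few cases:

#include <cstdint>

constexpr int kWarpSize = 32;  // assumption for this sketch

template <int BLOCK_SIZE, typename T>
constexpr int vec_size() {
  int group = kWarpSize / BLOCK_SIZE;                   // THREAD_GROUP_SIZE
  if (group < 1) group = 1;
  int v = 16 / (group * static_cast<int>(sizeof(T)));   // 16 bytes per group
  return v > 1 ? v : 1;
}

// BLOCK_SIZE 16, half: groups of 2 threads, 4 halves (8 bytes) per thread.
static_assert(vec_size<16, std::uint16_t>() == 4, "");
// BLOCK_SIZE 8, half: groups of 4 threads, 2 halves (4 bytes) per thread.
static_assert(vec_size<8, std::uint16_t>() == 2, "");
// BLOCK_SIZE 8, float: groups of 4 threads, 1 float (4 bytes) per thread.
static_assert(vec_size<8, float>() == 1, "");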
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). 
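+    // For instance, with 40 KV heads, head_size 128 and a block size of 16 the
+    // kv_block_stride is 40 * 128 * 16 = 81920 elements, so physical block
+    // numbers above roughly 26,000 would already overflow a 32-bit product.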
+ // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. 
+ v_vec = fp8::scaled_convert(v_quant_vec, + v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). 
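+// paged_attention_v1_kernel covers the whole sequence in a single pass, so it
+// has no partition dimension and passes nullptr for exp_sums / max_logits.
+// paged_attention_v2_kernel (below) runs one thread block per partition of
+// PARTITION_SIZE tokens and records per-partition softmax statistics that
+// paged_attention_v2_reduce_kernel later merges.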
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
+__global__ void paged_attention_v1_kernel(
+    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, seq_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE, int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                     // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
+      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs).
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
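+  // For each output element i this computes
+  //   out[i] = sum_j tmp_out[j][i] * shared_exp_sums[j] / global_exp_sum,
+  // where shared_exp_sums[j] = exp_sums[j] * exp(max_logits[j] - max_logit)
+  // already folds in the rescaling to the global max, so the per-partition
+  // softmax results are merged into one numerically stable softmax.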
+ const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. +template +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
+ case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int 
num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
+ case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/paged_attention/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh similarity index 78% rename from csrc/paged_attention/attention/attention_utils.cuh rename to csrc/attention/attention_utils.cuh index ff64c4b..826b0ed 100644 --- a/csrc/paged_attention/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,14 +27,14 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). A_vec qk_vec = mul(q[0], k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); } // Finalize the reduction across lanes. @@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/paged_attention/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh similarity index 87% rename from csrc/paged_attention/attention/dtype_bfloat16.cuh rename to csrc/attention/dtype_bfloat16.cuh index d36ce15..97a25ba 100644 --- a/csrc/paged_attention/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. 
-template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -92,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #else return __bfloat1622float2(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { @@ -100,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #else return __bfloat162bfloat162(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } // Vector addition. @@ -108,11 +112,12 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { @@ -121,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hadd2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -161,30 +167,32 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else return __hmul(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else return __hmul2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +200,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +209,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +219,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +230,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = 
bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +257,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +266,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +276,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,20 +288,24 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else return __hfma2(a, b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else return __hfma2(bf162bf162(a), b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { @@ -379,23 +391,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. 
-template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +460,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/paged_attention/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh similarity index 88% rename from csrc/paged_attention/attention/dtype_float16.cuh rename to csrc/attention/dtype_float16.cuh index 421419a..3a1815f 100644 --- a/csrc/paged_attention/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. 
-template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. 
-template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/paged_attention/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh similarity index 76% rename from csrc/paged_attention/attention/dtype_float32.cuh rename to csrc/attention/dtype_float32.cuh index b200d2d..7c6a686 100644 --- a/csrc/paged_attention/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. 
-template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. 
-inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh new file mode 100644 index 0000000..e714e32 --- /dev/null +++ b/csrc/attention/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h new file mode 100644 index 0000000..11c4c50 --- /dev/null +++ b/csrc/cache.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include +#include + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping); + +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale); + +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); + +// Just for unittest +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu new file mode 100644 index 0000000..1be806b --- /dev/null +++ b/csrc/cache_kernels.cu @@ -0,0 +1,405 @@ +#include +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#include +#include +#include +#include + +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + memcpy_type = cudaMemcpyDeviceToDevice; + } else if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "Invalid device combination"); + } + + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. 
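+  // Each row of block_mapping holds a (src_block, dst_block) pair; the loop
+  // below reads the pair on the host and issues one cudaMemcpyAsync of
+  // block_size_in_bytes for it on the current stream.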
+ TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); + } +} + +namespace vllm { + +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +} // namespace vllm + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + torch::Device cache_device = key_caches[0].device(); + TORCH_CHECK(cache_device.is_cuda()); + + // Create data structures for the kernel. + // Create an array of pointers to the key and value caches. + int64_t key_cache_ptrs[num_layers]; + int64_t value_cache_ptrs[num_layers]; + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); + } + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); + + // Move the data structures to the GPU. + // NOTE: This synchronizes the CPU and GPU. + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + + // Launch the kernel. 
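+  // One thread block per (layer, block pair): grid = (num_layers, num_pairs),
+  // and the threads of each block stride over the numel_per_block elements of
+  // that block in both the key and value caches.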
+ const int numel_per_block = key_caches[0][0].numel(); + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, numel_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); +} + +namespace vllm { + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, const float k_scale, + const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } + } +} + +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, + // head_size] + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size, + const float k_scale, const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = 
threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } + } +} +} // namespace vllm + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, k_scale, v_scale); + +void reshape_and_cache( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) +} + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. 
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)          \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>        \
+      <<<grid, block, 0, stream>>>(                                    \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                     \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                   \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),            \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),          \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,  \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
+void reshape_and_cache_flash(
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(1);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
+}
+
+namespace vllm {
+
+template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
+__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
+                                   Tout* __restrict__ dst_cache,
+                                   const float scale,
+                                   const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+    dst_cache[idx] =
+        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
+  }
+}
+
+}  // namespace vllm
+
+#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
+  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
+      reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
+      reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
+
+// Only for testing.
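Editorial aside, not part of the patch: the convert_fp8 entry point that follows is exposed only for tests. A rough C++ usage sketch is below; it assumes this translation unit links against the kernels above (convert_fp8 is declared in cache.h in this patch), that a CUDA device is available, and that the 0.5 per-tensor scale and the toy cache shape are example values only.

// Usage sketch: quantize a toy fp16 cache to fp8 and back via convert_fp8.
#include <torch/torch.h>

#include <string>

// Declared in cache.h in this patch; forward-declared here for the sketch.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);

void fp8_round_trip_example() {
  auto opts = torch::TensorOptions().device(torch::kCUDA);
  // Toy cache: 4 blocks of 8 heads x 64 elements x 16 tokens.
  auto src = torch::randn({4, 8, 64, 16}, opts.dtype(torch::kFloat16));
  auto quant = torch::empty_like(src, opts.dtype(torch::kUInt8));
  auto dequant = torch::empty_like(src);

  const double scale = 0.5;                   // example per-tensor scale
  convert_fp8(quant, src, scale, "fp8");      // fp16 -> fp8 (stored as uint8)
  convert_fp8(dequant, quant, scale, "fp8");  // fp8 -> fp16
  // dequant now approximates src to fp8 (e4m3) precision.
}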
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + } +} diff --git a/csrc/core/registration.h b/csrc/core/registration.h new file mode 100644 index 0000000..e5396e9 --- /dev/null +++ b/csrc/core/registration.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. 
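Editorial aside, not part of the patch: the REGISTER_EXTENSION macro defined just below pairs with TORCH_LIBRARY_EXPAND in each backend's torch_bindings.cpp. A minimal, hypothetical bindings file might look like the sketch below; my_op and its schema are placeholders, and TORCH_EXTENSION_NAME is assumed to be supplied by the build (e.g. -DTORCH_EXTENSION_NAME=_my_C).

// Hypothetical torch_bindings.cpp sketch using the helpers from registration.h.
#include <torch/library.h>

#include "registration.h"

// Assumed to be implemented elsewhere in the extension.
void my_op(torch::Tensor& out, torch::Tensor& in);

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Register the op schema and a CUDA implementation for it.
  ops.def("my_op(Tensor! out, Tensor in) -> ()");
  ops.impl("my_op", torch::kCUDA, &my_op);
}

// Makes `import <extension name>` work by exposing an (empty) Python module.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)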
+#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 0000000..b1e10fe --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,547 @@ +#pragma once + +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// ScalarTypeTorch is a subclass of ScalarType that is compatible with +// TORCH_LIBRARY, making it accessible from Python as well meaning this class +// can be used as a argument for custom operators, helping to simplify these +// interfaces. +// +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... 
rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from +// torch::CustomClassHolder), we use multiple inheritance here since we cannot +// have ScalarType inherit from torch::CustomClassHolder and have a constexpr +// constructor at the same time (torch::CustomClassHolder does not have a +// constexpr destructor) +// See also: +// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA +class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { + public: + ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, + bool _signed) + : ScalarType(exponent, mantissa, bias, _signed){}; + + ScalarTypeTorch(ScalarType type) : ScalarType(type){}; + + using Base = ScalarType; + using Self = ScalarTypeTorch; + using SelfPtr = c10::intrusive_ptr; + + static void check_size_bits(int64_t size_bits, bool signed_) { + TORCH_CHECK( + size_bits <= + std::numeric_limits().mantissa)>::max(), + "size_bits bit width is too large to be represented"); + } + + static void check_bias(int64_t bias) { + using Bias = decltype(std::declval().bias); + TORCH_CHECK(bias <= std::numeric_limits::max() && + bias >= std::numeric_limits::min(), + "bias too large or small to be represented"); + } + + static void check_exponent(int64_t exponent) { + TORCH_CHECK( + exponent <= + std::numeric_limits().exponent)>::max(), + "exponent bit width is too large to be represented"); + } + + static void check_mantissa(int64_t mantissa) { + TORCH_CHECK( + mantissa <= + std::numeric_limits().mantissa)>::max(), + "mantissa bit width is too large to be represented"); + } + + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + check_bias(bias.value_or(0)); + return c10::make_intrusive( + ScalarType::int_(size_bits, bias.value_or(0))); + } + + static SelfPtr uint(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + check_bias(bias.value_or(0)); + return c10::make_intrusive( + ScalarType::uint(size_bits, bias.value_or(0))); + } + + static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + check_mantissa(mantissa); + check_exponent(exponent); + return 
c10::make_intrusive( + ScalarType::float_IEEE754(exponent, mantissa)); + } + + static SelfPtr float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, int64_t nan_repr) { + check_mantissa(mantissa); + check_exponent(exponent); + return c10::make_intrusive(ScalarType::float_( + exponent, mantissa, finite_values_only, NanRepr(nan_repr))); + } + + // This needs to be implemented and throw a TypeError in order for + // PyTorch's opcheck to work on ops that use ScalarTypes. + int64_t len() const { + throw c10::TypeError("__len__ not implemented"); + return 0; + } + + // Serialize a ScalarType into a tuple of pairs. Where each pair + // is a (fieldname, value). + // For simplicity, we are just going to convert to a ScalarTypeId. + std::tuple> obj_flatten() const { + return {{"ScalarType", id()}}; + } + + // Deserialize a scalar type that has been serialized by obj_flatten, + // ostensibly from a tuple of (member name, value) pairs, but in reality + // just a ScalarTypeId. + static SelfPtr obj_unflatten( + std::tuple> const& flat_type) { + return c10::make_intrusive( + from_id(std::get<1>(std::get<0>(flat_type)))); + } + + template + static void bind_readonly_property(torch::class_& cls, + std::string const& name, T Base::*field) { + auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) { + if constexpr (std::is_member_function_pointer_v) { + return (self.get()->*field)(); + } else { + return self.get()->*field; + } + }; + + auto getter_func = [field = std::move(field), + getter_func_helper = std::move(getter_func_helper)]( + SelfPtr const& self) { + auto val = getter_func_helper(self); + // upconvert uint8_t, int32_t etc. to int64_t for python + if constexpr (std::is_integral_v) { + return static_cast(val); + } else { + return val; + } + }; + + cls.def_property(name, getter_func); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + MemberFunc Cls::*member) { + cls.def(name, [member = std::move(member)](SelfPtr const& self) { + return (self.get()->*member)(); + }); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + Func func) { + cls.def(name, func); + } + + template + static void bind_static_function(torch::class_& cls, + const std::string& name, Func func) { + cls.def_static(name, func); + } + + static void bind_class(torch::Library& lib) { + auto cls = lib.class_("ScalarType") + .def(torch::init()); + + // Bind Properties + bind_readonly_property(cls, "mantissa", &Base::mantissa); + bind_readonly_property(cls, "exponent", &Base::exponent); + bind_readonly_property(cls, "bias", &Base::bias); + bind_readonly_property(cls, "signed", &Base::is_signed); + bind_readonly_property(cls, "size_bits", &Base::size_bits); + + // Bind member functions + bind_function(cls, "is_signed", &Base::is_signed); + bind_function(cls, "is_integer", &Base::is_integer); + bind_function(cls, "is_floating_point", &Base::is_floating_point); + bind_function(cls, "is_ieee_754", &Base::is_ieee_754); + bind_function(cls, "has_nans", &Base::has_nans); + bind_function(cls, "has_infs", &Base::has_infs); + bind_function(cls, "has_bias", &Base::has_bias); + + bind_function(cls, "max", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->max()); + }); + bind_function(cls, "min", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->min()); + }); + + bind_function(cls, "__len__", &ScalarTypeTorch::len); + bind_function(cls, 
"__str__", &Base::str); + bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { + return *self == *other; + }); + bind_function(cls, "__repr__", [](SelfPtr const& self) { + return "ScalarType." + self.get()->str(); + }); + + bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten); + bind_static_function(cls, "__obj_unflatten__", + &ScalarTypeTorch::obj_unflatten); + + // Bind static functions (convenience constructors) + bind_static_function(cls, "int_", &ScalarTypeTorch::int_); + bind_static_function(cls, "uint", &ScalarTypeTorch::uint); + bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); + bind_static_function(cls, "float_", &ScalarTypeTorch::float_); + } +}; + +using ScalarTypeId = int64_t; +using ScalarTypeTorchPtr = c10::intrusive_ptr; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp new file mode 100644 index 0000000..f602541 --- /dev/null +++ b/csrc/core/torch_bindings.cpp @@ -0,0 +1,16 @@ +#include + +#include "scalar_type.hpp" +#include "registration.h" + +// Note the CORE exstension will be built for (almost) all hardware targets so +// new additions must account for this. (currently not built for TPU and Neuron) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { + // ScalarType, a custom class for representing data types that supports + // quantized types, declared here so it can be used when creating interfaces + // for custom ops. 
+ vllm::ScalarTypeTorch::bind_class(lib); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp new file mode 100644 index 0000000..039b8d5 --- /dev/null +++ b/csrc/cpu/activation.cpp @@ -0,0 +1,163 @@ +#include "cpu_types.hpp" + +namespace { +template +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + int start = i * d; + if constexpr (is_gated) { + start *= 2; + } + + const scalar_vec_t x(input + start + j); + const vec_op::FP32Vec8 f32_x(x); + vec_op::FP32Vec8 f32_ans = func(f32_x); + + if constexpr (is_gated) { + const scalar_vec_t y(input + start + d + j); + const vec_op::FP32Vec8 f32_y(y); + f32_ans = f32_y * f32_ans; + } + + const scalar_vec_t result(f32_ans); + result.save(output + i * d + j); + } + } +} + +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 x3 = x * x * x; + const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(1.702f); + return x / (ones + (zeros - w1 * x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); + const vec_op::FP32Vec8 w3(0.044715); + const vec_op::FP32Vec8 x_3 = x * x * x; + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); +} +}; // namespace + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + 
num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_tanh_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl) + }); +} + +void gelu_new(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_new_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_new_impl) + }); +} + +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_fast_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); +} + +void gelu_quick(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_quick_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_quick_impl) + }); +} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp new file mode 100644 index 0000000..abb4e3b --- /dev/null +++ b/csrc/cpu/attention.cpp @@ -0,0 +1,758 @@ +#include "cpu_types.hpp" + +namespace { + +template +struct KernelVecType { + using q_load_vec_type = void; + using q_vec_type = void; + using k_load_vec_type = void; + using k_vec_type = void; + using qk_acc_vec_type = void; + using v_load_vec_type = void; +}; + +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec4; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP32Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512BF16__ +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::BF16Vec32; + using k_vec_type = vec_op::BF16Vec32; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#else +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#endif + +template +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, + const int capacity) { + T max = data[0]; + for (int i = 1; i < size; ++i) { + max = max >= data[i] ? 
max : data[i]; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); + T max = data[0]; + for (int i = 1; i < size; ++i) { + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); + data[i] = qk; + max = max >= qk ? max : qk; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, + const int size) { + T max = max_data[0]; + for (int i = 1; i < size; ++i) { + max = max >= max_data[i] ? max : max_data[i]; + } + + T rescaled_sum = 0; + for (int i = 0; i < size; ++i) { + T rescale_factor = std::exp(max_data[i] - max); + rescaled_sum += rescale_factor * sum_data[i]; + sum_data[i] *= rescale_factor; + } + for (int i = 0; i < size; ++i) { + sum_data[i] /= rescaled_sum + 1e-8; + } +} + +template +struct reduceQKBlockKernel { + using q_load_vec_type = typename KernelVecType::q_load_vec_type; + using q_vec_type = typename KernelVecType::q_vec_type; + using k_load_vec_type = typename KernelVecType::k_load_vec_type; + using k_vec_type = typename KernelVecType::k_vec_type; + using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; + + constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; + constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; + constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; + + static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); + static_assert(k_load_vec_type::get_elem_num() % x == 0); + static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); + + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, + const int token_num) { + const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; + + qk_acc_vec_type group_accums[MAX_GROUP_NUM]; + if (token_num == BLOCK_SIZE) { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + + vec_op::unroll_loop( + [k_block, &q_group_vec, &group_accums](int token_group_idx) { + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } else { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + for (int token_group_start = 0; token_group_start < group_num; + token_group_start += UNROLL_GROUP_NUM) { + vec_op::unroll_loop( + [token_group_start, k_block, &q_group_vec, + &group_accums](int token_group_idx) { + token_group_idx += token_group_start; + k_load_vec_type 
k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } + } + + for (int token_group_idx = 0; token_group_idx < group_num; + ++token_group_idx) { + vec_op::unroll_loop( + [&group_accums, logits, scale, token_group_idx](int token_idx) { + float dot_v = + group_accums[token_group_idx] + .template reduce_sub_sum(token_idx); + logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = + dot_v * scale; + }); + } + } +}; + +template +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); + static_assert(BLOCK_SIZE == ELEM_NUM); + vec_op::FP32Vec16 prob_vec(prob); + + vec_op::unroll_loop([&](int head_elem_idx) { + v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); + vec_op::FP32Vec16 fp32_v_vec(v_vec); + acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; + }); +} +}; // namespace + +// Paged attention v1 +namespace { +template +struct paged_attention_v1_impl { + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); + + const int parallel_work_item_num = omp_get_max_threads(); + + size_t logits_bytes = + parallel_work_item_num * max_seq_len_padded * sizeof(float); + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. 
+ // [parallel_work_item_num, max_seq_len_padded] + +#pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int seq_len = seq_lens[seq_idx]; + const int* seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t* __restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_seq_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t* __restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float* __restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + // Compute softmax + if (alibi_slopes) { + reduceSoftmaxAlibi(thread_block_logits, seq_len, + block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, + seq_len); + } else { + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t* __restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float* __restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; + const scalar_t* __restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t* __restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + std::free(logits); + } +}; + +#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + +template +void paged_attention_v1_impl_launcher( + torch::Tensor& out, torch::Tensor& 
query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes); + +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); +} + +// Paged attention v2 +namespace { +template +struct paged_attention_v2_impl { + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, 
block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads, const int max_num_partitions) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); + static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); + +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int partition_idx = 0; partition_idx < max_num_partitions; + ++partition_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int seq_len = seq_lens[seq_idx]; + const int start_token_idx = partition_idx * PARTITION_SIZE; + + if (start_token_idx >= seq_len) continue; + + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + const bool no_reduce = (partition_num == 1); + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - + start_token_idx); + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int last_block_token_num = + token_num - (block_num - 1) * BLOCK_SIZE; + const int* seq_block_table = block_tables + + max_num_blocks_per_seq * seq_idx + + start_token_idx / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t* __restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t* __restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float* __restrict__ head_block_logits = + logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? 
last_block_token_num : BLOCK_SIZE); + } + + std::pair max_and_sum; + if (alibi_slopes) { + max_and_sum = reduceSoftmaxAlibi( + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); + } else { + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); + } + + auto&& [max_logit, exp_sum] = max_and_sum; + + scalar_t* __restrict__ output_buffer = nullptr; + if (!no_reduce) { + auto idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + max_logits[idx] = max_logit; + exp_sums[idx] = exp_sum; + output_buffer = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + output_buffer = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t* __restrict__ out_ptr = + output_buffer + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float* __restrict__ prob_vec_ptr = + logits + block_idx * BLOCK_SIZE; + const scalar_t* __restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t* __restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + } + + // Rescale partition softmax and store the factors to exp_sums +#pragma omp parallel for collapse(2) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int seq_len = seq_lens[seq_idx]; + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) continue; + + reducePartitonSoftmax( + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + partition_num); + } + } + + // Reduce values + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); + constexpr int head_elem_num_per_group = + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes + static_assert(HEAD_SIZE % head_elem_num_per_group == 0); + constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; + const float* __restrict__ rescale_factors = exp_sums; +#pragma omp parallel for collapse(3) 
schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { + const int seq_len = seq_lens[seq_idx]; + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) continue; + + const float* __restrict__ seq_head_rescale_factors = + rescale_factors + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + const scalar_t* __restrict__ seq_head_tmp_out = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + group_idx * head_elem_num_per_group; + scalar_t* __restrict__ seq_head_output = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + group_idx * head_elem_num_per_group; + + vec_op::FP32Vec16 acc; + for (int i = 0; i < partition_num; ++i) { + vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); + v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); + vec_op::FP32Vec16 fp32_value(value); + acc = acc + fp32_value * rescale_factor; + } + v_load_vec_type cast_acc(acc); + cast_acc.save(seq_head_output); + } + } + } + } +}; + +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ + max_num_partitions); + +template +void paged_attention_v2_impl_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + int max_num_partitions = exp_sums.size(-1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); +} diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp new file mode 100644 index 0000000..31d4543 --- /dev/null +++ b/csrc/cpu/cache.cpp @@ -0,0 +1,138 @@ +#include +#include + +#include "cpu_types.hpp" + +namespace { +template +void copy_blocks_cpu_impl(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { + const size_t pair_num = mapping_pairs.size(0); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; +#pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); + int64_t target_offset = + element_num_per_block * mapping_pairs[pair][1].item(); + 
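+      // Each mapping pair is (src_block_number, dst_block_number): the body
+      // below memcpy's one block's worth of elements (block_bytes) from the
+      // source block to the destination block, once in this layer's key
+      // cache and once in its value cache, using the same offsets for both.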
scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + } + } +} + +template +void reshape_and_cache_cpu_impl( + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx >= 0) { + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + scalar_t* target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + scalar_t* target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + + for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { + const int64_t target_offset = + src_key_idx * block_size + block_offset * x; + for (int i = 0; i < x; ++i) { + target_key_head_ptr[target_offset + i] = + src_key_head_ptr[src_key_idx + i]; + } + } + + for (int src_value_idx = 0; src_value_idx < head_size; + ++src_value_idx) { + const int64_t target_offset = + src_value_idx * block_size + block_offset; + target_value_head_ptr[target_offset] = + src_value_head_ptr[src_value_idx]; + } + } + } + } +} +}; // namespace + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. 
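// A minimal driver sketch, not part of the patch above; it assumes this file's
// includes and a forward declaration of copy_blocks. The [num_pairs, 2] layout
// of block_mapping with (src_block, dst_block) rows and the int64 dtype are
// inferred from the loop in copy_blocks_cpu_impl; copy_two_blocks_example is a
// hypothetical helper name.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

static void copy_two_blocks_example(
    std::vector<torch::Tensor> const& key_caches,
    std::vector<torch::Tensor> const& value_caches) {
  // Copy block 0 -> 2 and block 1 -> 3 in every layer's key and value cache.
  torch::Tensor block_mapping =
      torch::tensor({0, 2, 1, 3}, torch::kInt64).reshape({2, 2});
  copy_blocks(key_caches, value_caches, block_mapping);
}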
+void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { + unsigned num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); +} + +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, double k_scale, + double v_scale) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); + + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, + value_stride, num_heads, head_size, block_size, x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); +} + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") +} diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp new file mode 100644 index 0000000..0213be0 --- /dev/null +++ b/csrc/cpu/cpu_types.hpp @@ -0,0 +1,15 @@ + +#ifndef CPU_TYPES_HPP +#define CPU_TYPES_HPP + +#if defined(__x86_64__) + //x86 implementation + #include "cpu_types_x86.hpp" +#elif defined(__POWER9_VECTOR__) + //ppc implementation + #include "cpu_types_vsx.hpp" +#else + #warning "unsupported vLLM cpu implementation" +#endif + +#endif diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp new file mode 100644 index 0000000..b50bdad --- /dev/null +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -0,0 +1,491 @@ + +#ifndef CPU_TYPES_VSX_HPP +#define CPU_TYPES_VSX_HPP + +#include +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
<< std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +typedef struct ss16x8x2_t { + __vector signed short val[2]; +} ss16x8x2_t; + +typedef struct ss16x8x4_t { + __vector signed short val[4]; +} ss16x8x4_t; + +typedef struct f32x4x2_t { + __vector float val[2]; +} f32x4x2_t; + +typedef struct f32x4x4_t { + __vector float val[4]; +} f32x4x4_t; + +struct FP32Vec8; +struct FP32Vec16; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __vector signed short reg; + + explicit BF16Vec8(const void *ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + ss16x8x2_t reg; + + explicit BF16Vec16(const void *ptr) { + // Load 256 bits in two parts + reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + } + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { + // Save 256 bits in two parts + vec_xst(reg.val[0], 0, (signed short *)ptr); + vec_xst(reg.val[1], 16, (signed short *)ptr); + } +}; + +const static __vector signed short zero = vec_splats((signed short)0); + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + ss16x8x4_t reg; + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} + + explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ + vec8_data.reg, + vec8_data.reg, + vec8_data.reg, + vec8_data.reg + }) {} + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __vector float reg; + float values[VEC_ELEM_NUM]; + }; + + __vector float reg; + + explicit FP32Vec4(float v) : reg(vec_splats(v)) {} + + explicit FP32Vec4() : reg(vec_splats(0.0f)) {} + + explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + + explicit FP32Vec4(__vector float data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + f32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x2_t reg; + + explicit FP32Vec8(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + } + + explicit FP32Vec8() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + } + + explicit FP32Vec8(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + } + + explicit FP32Vec8(f32x4x2_t data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + } + + explicit FP32Vec8(const BF16Vec8 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { 
result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::exp(ar.values[0]); + ret.val[0][1] = std::exp(ar.values[1]); + ret.val[0][2] = std::exp(ar.values[2]); + ret.val[0][3] = std::exp(ar.values[3]); + ret.val[1][0] = std::exp(ar.values[4]); + ret.val[1][1] = std::exp(ar.values[5]); + ret.val[1][2] = std::exp(ar.values[6]); + ret.val[1][3] = std::exp(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 tanh() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::tanh(ar.values[0]); + ret.val[0][1] = std::tanh(ar.values[1]); + ret.val[0][2] = std::tanh(ar.values[2]); + ret.val[0][3] = std::tanh(ar.values[3]); + ret.val[1][0] = std::tanh(ar.values[4]); + ret.val[1][1] = std::tanh(ar.values[5]); + ret.val[1][2] = std::tanh(ar.values[6]); + ret.val[1][3] = std::tanh(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 er() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::erf(ar.values[0]); + ret.val[0][1] = std::erf(ar.values[1]); + ret.val[0][2] = std::erf(ar.values[2]); + ret.val[0][3] = std::erf(ar.values[3]); + ret.val[1][0] = std::erf(ar.values[4]); + ret.val[1][1] = std::erf(ar.values[5]); + ret.val[1][2] = std::erf(ar.values[6]); + ret.val[1][3] = std::erf(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + f32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x4_t reg; + + explicit FP32Vec16(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + reg.val[2] = vec_splats(v); + reg.val[3] = vec_splats(v); + } + + explicit FP32Vec16() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + reg.val[2] = vec_splats(0.0f); + reg.val[3] = vec_splats(0.0f); + } + + explicit FP32Vec16(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + reg.val[2] = vec_xl(32, ptr); + reg.val[3] = vec_xl(48, ptr); + } + + explicit FP32Vec16(f32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[2]; + reg.val[3] = data.reg.val[3]; + } + + explicit FP32Vec16(const FP32Vec4 &data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + } + + explicit FP32Vec16(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const BF16Vec16 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, 
v.reg.val[0]); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); + reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); + reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float result = 0; + const int start = idx * group_size; + unroll_loop( + [&result, &start, ar](int i) { result += ar.values[start + i]; }); + + return result; + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + vec_xst(reg.val[2], 32, ptr); + vec_xst(reg.val[3], 48, ptr); + } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifndef __VEC_CLASS_FP_NAN +#define __VEC_CLASS_FP_NAN (1 << 6) +#endif + +const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +#ifndef _ARCH_PWR10 +const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; +const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; +const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; +const static __vector unsigned int one = { 1, 1, 1, 1 }; +#endif + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[2]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + reg = vec_perm(ret[0], ret[1], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + lsb0 
= vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + reg = (__vector signed short)vec_perm(inp0, inp1, omask); +#endif +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[4]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + reg.val[0] = vec_perm(ret[0], ret[1], omask); + reg.val[1] = vec_perm(ret[2], ret[3], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); + __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + __vector unsigned int lsb2 = vec_sr(inp2, sh16); + __vector unsigned int lsb3 = vec_sr(inp3, sh16); + lsb0 = vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + lsb2 = vec_and(lsb2, one); + lsb3 = vec_and(lsb3, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + __vector unsigned int rnd2 = vec_add(lsb2, bias); + __vector unsigned int rnd3 = vec_add(lsb3, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + inp2 = vec_add(inp2, rnd2); + inp3 = vec_add(inp3, rnd3); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp2 = vec_sel(inp2, nan, sel2); + inp3 = vec_sel(inp3, nan, sel3); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + inp2 = vec_sr(inp2, sh16); + inp3 = vec_sr(inp3, sh16); + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); + reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); +#endif +} + +inline void prefetch(const void *addr) { + __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); +} + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp new file mode 100644 index 0000000..f50620a --- /dev/null +++ b/csrc/cpu/cpu_types_x86.hpp @@ -0,0 +1,515 @@ + +#ifndef CPU_TYPES_X86_HPP +#define CPU_TYPES_X86_HPP + +#include +#include + +#ifndef __AVX2__ +static_assert(false, "AVX2 must be supported for the current implementation."); +#endif + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +#ifdef __AVX512F__ +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; +#else +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m256i reg_low; + __m256i reg_high; + + explicit BF16Vec32(const void *ptr) + : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), + reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + + explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), + reg_high(high) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg_low((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), + reg_high((__m256i)_mm256_inserti32x4( + 
_mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} + + void save(void *ptr) const { + *reinterpret_cast<__m256i *>(ptr) = reg_low; + *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + } +}; +#endif + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +#ifdef __AVX512F__ +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit 
FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; +#else +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + union AliasReg { + __m256 reg; + float values[8]; + }; + + __m256 reg_low; + __m256 reg_high; + + explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), + reg_high(_mm256_set1_ps(v)) {} + + explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), + reg_high(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), + reg_high(_mm256_loadu_ps(ptr + 8)) {} + + explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), + reg_high(data.reg_high) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg_low((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)), + reg_high((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg_low(data.reg), reg_high(data.reg) {} + + explicit FP32Vec16(const BF16Vec16 &v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); + __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); + + __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); + __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); + + reg_low = _mm256_castsi256_ps(v_low_shifted); + reg_high = _mm256_castsi256_ps(v_high_shifted); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), + _mm256_mul_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), + _mm256_add_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), + _mm256_sub_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), + _mm256_div_ps(reg_high, b.reg_high)); + } + + float reduce_sum() const { + FP32Vec8 low = FP32Vec8(reg_low); + FP32Vec8 high = FP32Vec8(reg_high); + return low.reduce_sum() + high.reduce_sum(); + } + + template float reduce_sub_sum(int 
idx) { + float sum = 0.0; + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + uint32_t mask = base_mask << (idx * group_size); + + AliasReg ar; + + auto func = [&sum, &mask, &ar](int i) { + int flag = mask & 0x1; + mask = mask >> 1; + if (flag != 0) sum += ar.values[i]; + }; + + ar.reg = reg_low; + unroll_loop(func); + + ar.reg = reg_high; + unroll_loop(func); + + return sum; + } + + void save(float *ptr) const { + _mm256_storeu_ps(ptr, reg_low); + _mm256_storeu_ps(ptr + 8, reg_high); + } +}; +#endif + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifdef __AVX512F__ +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#else +namespace{ +__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { + __m256i ai = _mm256_castps_si256(a); + ai = _mm256_srli_epi32(ai, 16); + ai = _mm256_packus_epi32(ai, ai); + ai = _mm256_permute4x64_epi64(ai, 0b00111001); + return _mm256_extracti128_si256(ai, 0); +} +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { + BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); + BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); + reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); +} +#endif // __AVX512F__ +#endif // __AVX512BF16__ + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp new file mode 100644 index 0000000..a76ad08 --- /dev/null +++ b/csrc/cpu/layernorm.cpp @@ -0,0 +1,117 @@ +#include "cpu_types.hpp" + +namespace { +template +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + 
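// Per token: the first pass accumulates the sum of squares into `variance`,
// then s_variance = 1 / sqrt(sum(x^2) / hidden_size + epsilon), and the second
// pass stores x * s_variance * weight.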
vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + vec_op::FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_w(w); + + vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } +} + +template +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto residual_p = residual + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t res(residual_p + j); + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_res(res); + + fp32_x = fp32_x + fp32_res; + variance = variance + fp32_x * fp32_x; + scalar_vec_t out(fp32_x); + out.save(residual_p + j); + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t w(weight + j); + scalar_vec_t res(residual_p + j); + + vec_op::FP32Vec8 fp32_w(w); + vec_op::FP32Vec8 fp32_res(res); + + vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(input_p + j); + } + } +} +} // namespace + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) + }); +} + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "fused_add_rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl) + fused_add_rms_norm_impl( + input.data_ptr(), residual.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl) + }); +} diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp new file mode 100644 index 0000000..96bce7d --- /dev/null +++ b/csrc/cpu/pos_encoding.cpp @@ -0,0 +1,199 @@ + +#include "cpu_types.hpp" + +namespace { +template +void rotary_embedding_impl( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + 
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + const int embed_dim = rot_dim / 2; + bool flag = (embed_dim % VEC_ELEM_NUM == 0); + const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM; + + auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, + scalar_t* qk) { + int j = 0; + for (; j < loop_upper; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(qk + out_x); + const scalar_vec_t q_y(qk + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(qk + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(qk + out_y); + } + if (!flag) { + for (; j < embed_dim; ++j) { + const int x_index = j; + const int y_index = embed_dim + j; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const float fp32_cos = cache_ptr[x_index]; + const float fp32_sin = cache_ptr[y_index]; + + const float fp32_q_x = qk[out_x]; + const float fp32_q_y = qk[out_y]; + + qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + } + } + }; + +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, query); + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, key); + } + } +} + +template +void rotary_embedding_gptj_impl( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + const int embed_dim = rot_dim / 2; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* 
cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + scalar_t* head_query = token_head + query; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_query[x_index]; + const float y = head_query[y_index]; + + head_query[x_index] = x * cos - y * sin; + head_query[y_index] = y * cos + x * sin; + } + } + } + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_kv_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + scalar_t* head_key = key + token_head; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_key[x_index]; + const float y = head_key[y_index]; + + head_key[x_index] = x * cos - y * sin; + head_key[y_index] = y * cos + x * sin; + } + } + } +} +}; // namespace + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding_impl", [&] { + CPU_KERNEL_GUARD_IN(rotary_embedding_impl) + if (is_neox) { + rotary_embedding_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } else { + rotary_embedding_gptj_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } + + CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) + }); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp new file mode 100644 index 0000000..cf7d977 --- /dev/null +++ b/csrc/cpu/torch_bindings.cpp @@ -0,0 +1,117 @@ +#include "cache.h" +#include "ops.h" +#include "core/registration.h" + +#include + +void init_cpu_threads_env(const std::string& cpu_ids); + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached keys/values + // using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? 
alibi_slopes," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); + + // Activation ops + + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCPU, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCPU, &gelu_fast); + + // Quick GELU implementation. + ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_quick", torch::kCPU, &gelu_quick); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCPU, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! 
value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { + // CPU utils + utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp new file mode 100644 index 0000000..5782580 --- /dev/null +++ b/csrc/cpu/utils.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include + +#include "cpu_types.hpp" + +void init_cpu_threads_env(const std::string& cpu_ids) { + bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); + TORCH_CHECK(omp_cpu_mask->size > 0); + std::vector omp_cpu_ids; + omp_cpu_ids.reserve(omp_cpu_mask->size); + + constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); + + for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { + unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; + int i = 0; + while (group_mask) { + if (group_mask & 1) { + omp_cpu_ids.emplace_back(offset + i); + } + ++i; + group_mask >>= 1; + } + } + + // Memory node binding + if (numa_available() != -1) { + int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); + bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); + bitmask* src_mask = numa_get_membind(); + + int pid = getpid(); + + // move all existing pages to the specified numa node. + *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); + int page_num = numa_migrate_pages(pid, src_mask, mask); + if (page_num == -1) { + TORCH_CHECK(false, + "numa_migrate_pages failed. errno: " + std::to_string(errno)); + } + + // restrict memory allocation node. 
+ numa_set_membind(mask); + numa_set_strict(1); + } + + // OMP threads binding + omp_set_num_threads((int)omp_cpu_ids.size()); + torch::set_num_threads((int)omp_cpu_ids.size()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); +#pragma omp parallel for schedule(static, 1) + for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { + cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); + size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); + CPU_ZERO_S(size, mask); + CPU_SET_S(omp_cpu_ids[i], size, mask); + sched_setaffinity(0, sizeof(cpu_set_t), mask); + CPU_FREE(mask); + } + + numa_free_nodemask(omp_cpu_mask); +} diff --git a/csrc/paged_attention/cuda_compat.h b/csrc/cuda_compat.h similarity index 52% rename from csrc/paged_attention/cuda_compat.h rename to csrc/cuda_compat.h index 6408fcd..82e5561 100644 --- a/csrc/paged_attention/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,5 +1,15 @@ #pragma once +#ifdef USE_ROCM + #include +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + #ifndef USE_ROCM #define VLLM_LDG(arg) __ldg(arg) #else @@ -7,9 +17,14 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM @@ -18,10 +33,17 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) #else #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) -#endif \ No newline at end of file +#endif diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h new file mode 100644 index 0000000..c352242 --- /dev/null +++ b/csrc/cuda_utils.h @@ -0,0 +1,15 @@ +#pragma once + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ + #define DEVICE_INLINE __forceinline__ __device__ + #define HOST_INLINE __forceinline__ __host__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + +int64_t get_device_attribute(int64_t attribute, int64_t device_id); + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu new file mode 100644 index 0000000..d6f9eb6 --- /dev/null +++ b/csrc/cuda_utils_kernels.cu @@ -0,0 +1,29 @@ +#ifdef USE_ROCM + #include + #include +#endif +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; +} + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t 
device_id) { + int64_t attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 + +#ifdef USE_ROCM + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; +#else + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; +#endif + + return get_device_attribute(attribute, device_id); +} diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu new file mode 100644 index 0000000..82a3563 --- /dev/null +++ b/csrc/custom_all_reduce.cu @@ -0,0 +1,153 @@ +#include +#include +#include +#include + +#include "custom_all_reduce.cuh" + +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink) { + int world_size = offsets.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + cudaIpcMemHandle_t ipc_handles[8]; + for (int i = 0; i < world_size; i++) { + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + } + return (fptr_t) new vllm::CustomAllreduce( + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, + bool full_nvlink) { + auto inp_size = inp.numel() * inp.element_size(); + // custom allreduce requires input byte size to be multiples of 16 + if (inp_size % 16 != 0) return false; + if (!_is_weak_contiguous(inp)) return false; + if (world_size == 2 || full_nvlink) return inp_size <= max_size; + // for 4 or more non NVLink-capable GPUs, custom allreduce provides little + // performance improvement over NCCL. 
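// Everything else (4 or more ranks without full NVLink) falls back to NCCL.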
+ return false; +} + +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + cudaStream_t stream) { + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(_is_weak_contiguous(out)); + switch (out.scalar_type()) { + case at::ScalarType::Float: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } + case at::ScalarType::Half: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, cudaMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +int64_t meta_size() { return sizeof(vllm::Signal); } + +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); +} + +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; +} + +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); +} diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh new file mode 100644 index 0000000..1ed49b8 --- /dev/null +++ b/csrc/custom_all_reduce.cuh @@ -0,0 +1,482 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +constexpr int kMaxBlocks = 64; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links 
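// Handshake used by start_sync/end_sync below: for each barrier a rank writes
// its flag into every peer's Signal and then spins on its own copies until all
// peers have checked in. Each barrier also clears the flags of the other
// barrier (start_sync clears `end`, end_sync clears `start`), which prepares
// them for reuse on the following barrier or kernel launch.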
+struct Signal { + alignas(128) uint32_t start[kMaxBlocks][8]; + alignas(128) uint32_t end[kMaxBlocks][8]; +}; + +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; + +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) { return a += b; } + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, + int rank) { + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->start[blockIdx.x][threadIdx.x]); + } + __syncthreads(); +} + +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. +template +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, + int rank) { + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. 
Might be the case that hardware provides stronger guarantee than + // the memory model. + if constexpr (!final_sync) __threadfence_system(); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->end[blockIdx.x][threadIdx.x]); + } + if constexpr (!final_sync) __syncthreads(); +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast

(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, + int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + start_sync(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + end_sync(sg, self_sg, rank); +} + +template +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, + int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); + } + auto tmp_out = tmps[0]; + start_sync(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + end_sync(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + // below are device pointers + RankSignals sg_; + std::unordered_map buffers_; + Signal* self_sg_; + + // stores the registered device pointers from all ranks + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * meta is a pointer to device metadata and temporary buffer for allreduce. + * + * There's a total of sizeof(Signal) of prefix before the actual data, + * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. Any required buffers + * are passed in from the constructor + */ + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, + bool full_nvlink = true) + : rank_(rank), + world_size_(offsets.size()), + full_nvlink_(full_nvlink), + self_sg_(meta), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + Signal* rank_sg; + if (i != rank_) { + char* handle = open_ipc_handle(&handles[i]); + handle += offsets[i]; + rank_sg = (Signal*)handle; + } else { + rank_sg = self_sg_; + } + sg_.signals[i] = rank_sg; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair, std::vector> + get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::vector handles(handle_sz * num_buffers, 0); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + 
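+    // [Editorial note, not part of the original patch] `handles` is a flat
+    // byte blob holding num_buffers cudaIpcMemHandle_t values back to back,
+    // and offsets[i] records where each original pointer sits inside its
+    // allocation; both are exchanged across ranks and fed back into
+    // register_graph_buffers() below.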
return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + if (i != rank_) { + char* handle = open_ipc_handle(handles[i].data()); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[self] = d_data; + } + + // note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( + const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * This is the result after careful grid search. Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. + */ + template + void allreduce(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". 
Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, + half *, int, int, int); +*/ +} // namespace vllm diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu new file mode 100644 index 0000000..f786823 --- /dev/null +++ b/csrc/custom_all_reduce_test.cu @@ -0,0 +1,316 @@ +/** + * This is a standalone test for custom allreduce. + * To compile, make sure you have MPI and NCCL installed in your system. + * export MPI_HOME=XXX + * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o + * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * + * Warning: this C++ test is not designed to be very readable and was used + * during the rapid prototyping process. 
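+ *
+ * [Editorial note, not part of the original patch] Depending on the
+ * `performance_test` flag set in main(), each problem size is either
+ * benchmarked against ncclAllReduce or cross-checked for correctness against
+ * an NCCL reference reduction.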
+ * + * To run: + * mpirun -np 8 ./custom_all_reduce_test + */ +#include +#include +#include +#include + +#include +#include + +#include "cuda_profiler_api.h" +#include "custom_all_reduce.cuh" +#include "mpi.h" +#include "nccl.h" + +#define MPICHECK(cmd) \ + do { \ + int e = cmd; \ + if (e != MPI_SUCCESS) { \ + printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ + ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +__global__ void dummy_kernel() { + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +} + +template +__global__ void set_data(T* data, int size, int myRank) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + data[idx] = myRank * 0.11f; + } +} + +template +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + fdata1[idx] = data1[idx]; + fdata2[idx] = data2[idx]; + } +} + +__global__ void init_rand(curandState_t* state, int size, int nRanks) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + for (int i = 0; i < nRanks; i++) { + curand_init(i + 1, idx, 0, &state[idx * nRanks + i]); + } + } +} + +template +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, + int myRank, int nRanks, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + double sum = 0.0; + for (int i = 0; i < nRanks; i++) { + double val = curand_uniform_double(&state[idx * nRanks + i]) * 4; + T hval = val; // downcast first + sum += static_cast(hval); + if (i == myRank) data[idx] = hval; + } + ground_truth[idx] = sum; + } +} + +template +void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, + int data_size, bool performance_test) { + T* result; + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); + CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); + + cudaIpcMemHandle_t self_data_handle; + cudaIpcMemHandle_t data_handles[8]; + vllm::Signal* buffer; + T* self_data_copy; + /** + * Allocate IPC buffer + * + * The first section is a temporary buffer for storing intermediate allreduce + * results, if a particular algorithm requires it. The second section is for + * the input to the allreduce. The actual API takes the input pointer as an + * argument (that is, they can and usually should be allocated separately). + * But since the input pointers and the temporary buffer all require IPC + * registration, they are allocated and registered together in the test for + * convenience. 
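+ *
+ * [Editorial note, not part of the original patch] The resulting layout of
+ * `buffer` is therefore, in order:
+ *   [ vllm::Signal | tmp: data_size * sizeof(T) | input: data_size * sizeof(T) ]
+ * which is why `self_data` below is computed as
+ *   buffer + sizeof(vllm::Signal) + data_size * sizeof(T).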
+ */ + CUDACHECK( + cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); + CUDACHECK( + cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); + CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); + CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); + + MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, MPI_COMM_WORLD)); + + void* rank_data; + size_t rank_data_sz = 16 * 1024 * 1024; + CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); + std::vector offsets(nRanks, 0); + vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, + offsets, myRank); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); + // hack buffer registration + { + std::vector handles; + handles.reserve(nRanks); + for (int i = 0; i < nRanks; i++) { + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; + handles.emplace_back(begin, end); + } + std::vector offsets(nRanks, + sizeof(vllm::Signal) + data_size * sizeof(T)); + fa.register_buffer(handles, offsets, self_data); + } + + double* ground_truth; + CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); + curandState_t* states; + CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); + init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); + gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, + nRanks, data_size); + CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + cudaEvent_t start, stop; + CUDACHECK(cudaEventCreate(&start)); + CUDACHECK(cudaEventCreate(&stop)); + + ncclDataType_t ncclDtype; + if (std::is_same::value) { + ncclDtype = ncclFloat16; + } else if (std::is_same::value) { + ncclDtype = ncclBfloat16; + } else { + ncclDtype = ncclFloat; + } + double *nccl_result, *my_result; + CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); + CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); + if (performance_test) { + dummy_kernel<<<1, 1, 0, stream>>>(); + constexpr int warmup_iters = 5; + constexpr int num_iters = 100; + // warmup + for (int i = 0; i < warmup_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, + comm, stream)); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, + comm, stream)); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + float allreduce_ms = 0; + cudaEventElapsedTime(&allreduce_ms, start, stop); + + dummy_kernel<<<1, 1, 0, stream>>>(); + // warm up + for (int i = 0; i < warmup_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + + float duration_ms = 0; + cudaEventElapsedTime(&duration_ms, start, stop); + if (myRank == 0) + printf( + "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " + "time:%.2fus\n", + myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, + duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); + + // 
And wait for all the queued up work to complete + CUDACHECK(cudaStreamSynchronize(stream)); + + NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, + ncclSum, comm, stream)); + + convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, + my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 4e-2) { + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); + break; + } + } + long double nccl_diffs = 0.0; + long double my_diffs = 0.0; + for (int j = 0; j < data_size; j++) { + nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + my_diffs += abs(my_result[j] - ground_truth[j]); + } + if (myRank == 0) + std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + << " me: " << my_diffs / data_size << std::endl; + } else { + for (int i = 0; i < 100; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + CUDACHECK(cudaStreamSynchronize(stream)); + NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype, + ncclSum, comm, stream)); + convert_data<<<108, 1024, 0, stream>>>( + self_data_copy, result, nccl_result, my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 4e-2) { + printf( + "Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); + break; + } + } + } + if (myRank == 0) + printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks, + data_size * sizeof(T) / 1024, threads, block_limit); + // long double nccl_diffs = 0.0; + // long double my_diffs = 0.0; + // for (int j = 0; j < data_size; j++) { + // nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + // my_diffs += abs(my_result[j] - ground_truth[j]); + // } + // if (myRank == 0) + // std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + // << " me: " << my_diffs / data_size << std::endl; + } + + CUDACHECK(cudaFree(result)); + CUDACHECK(cudaFree(self_data_copy)); + CUDACHECK(cudaFree(rank_data)); + CUDACHECK(cudaFree(buffer)); + CUDACHECK(cudaFree(states)); + CUDACHECK(cudaFreeHost(ground_truth)); + CUDACHECK(cudaFreeHost(nccl_result)); + CUDACHECK(cudaFreeHost(my_result)); + CUDACHECK(cudaStreamDestroy(stream)); +} + +int main(int argc, char** argv) { + int nRanks, myRank; + MPICHECK(MPI_Init(&argc, &argv)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); + MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); + CUDACHECK(cudaSetDevice(myRank)); + ncclUniqueId id; + ncclComm_t comm; + if (myRank == 0) ncclGetUniqueId(&id); + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPI_COMM_WORLD)); + NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + + bool performance_test = true; + cudaProfilerStart(); + // for (int threads : {256, 512}) { + // for (int block_limit = 16; block_limit < 112; block_limit += 4) { + // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // } + // } + for (int sz = 512; sz <= (8 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); + } + + cudaProfilerStop(); + return EXIT_SUCCESS; +} diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 0000000..1842fab --- /dev/null +++ 
b/csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) + return true; + else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 0000000..1618a34 --- /dev/null +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,154 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... 
a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + return tensor.stride(idx); + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + c10::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. 
get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 0000000..085ee12 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 0000000..6146bdc --- /dev/null +++ b/csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. 
symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 0000000..4fcfcd3 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,49 @@ +import enum +from typing import Dict, Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecializedMixedInput = enum_auto() + TmaWarpSpecializedPingpongMixedInput = enum_auto() + TmaWarpSpecializedCooperativeMixedInput = enum_auto() + + +VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + } +} + +VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + } +} + +VLLMKernelScheduleTag: Dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + } + } diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 0000000..2ad914f --- /dev/null +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,795 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. 
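+//
+// [Editorial note, not part of the original patch] The source types handled
+// below are biased: vllm_uint4b8_t stores x + 8 in 4 bits (0x0 -> -8,
+// 0x8 -> 0, 0xF -> +7) and vllm_uint8b128_t stores x + 128 in 8 bits. A
+// scalar sketch of what the vectorized converters compute per element
+// (illustrative only; the real converters handle 2-8 elements at a time with
+// prmt/lop3/hfma on packed registers):
+//
+//   float dequant_u4b8(uint8_t stored)   { return int(stored & 0xF) - 8; }
+//   float dequant_u8b128(uint8_t stored) { return int(stored) - 128; }
+//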
+template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + CUTE_INVALID_CONTROL_PATH( + "InterleavedNumericArrayConverter not implemented\n"); + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// TODO (LucasWilkinson): Implement +// for Array <= Array + +// .... + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + if constexpr (sizeof(PackedSrc) == 1) { + return static_cast(reinterpret_cast(source)); + } else if constexpr (sizeof(PackedSrc) == 2) { + return static_cast(reinterpret_cast(source)); + } else { + static_assert(sizeof(PackedSrc) == 4); + return reinterpret_cast(source); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. 
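+  //
+  // [Editorial note, not part of the original patch] Worked example of that
+  // trick for one biased 4-bit value stored as (x + 8): the fp16 bit pattern
+  // 0x6400 is 1024.0, and at that exponent one mantissa ulp is exactly 1.0,
+  // so OR-ing the nibble into the low mantissa bits yields exactly
+  // 1024 + (x + 8). A single fp16 subtract of 1032 then recovers x without
+  // ever issuing an integer-to-float conversion:
+  //
+  //   0x6400 | (x + 8)  ==  1024 + (x + 8)  --(subtract 1032)-->  x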
+ template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_reg(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. 
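+      // [Editorial note, not part of the original patch] lop3.b32 applies an
+      // arbitrary 3-input boolean function selected by an 8-bit truth table
+      // (immLut). The table for f(a, b, c) = (a & b) ^ c is obtained by
+      // evaluating f on the canonical constants a = 0xF0, b = 0xCC, c = 0xAA:
+      //   (0xF0 & 0xCC) ^ 0xAA == 0xC0 ^ 0xAA == 0x6A
+      // which is exactly how immLut is written below.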
+ static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return 
ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. 
the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. 
We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + 
static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h new file mode 100644 index 0000000..a634e1c --- /dev/null +++ b/csrc/dispatch_utils.h @@ -0,0 +1,35 @@ +/* + * Adapted from + * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h + */ +#pragma once + +#include + +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu new file mode 100644 index 0000000..7a7a25d --- /dev/null +++ b/csrc/layernorm_kernels.cu @@ -0,0 +1,357 @@ +#include +#include +#include + +#include "dispatch_utils.h" +#ifndef USE_ROCM + #include + #include + #include + #include +#else + #include + #include + #include + #include + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +namespace vllm { + +// TODO(woosuk): Further optimize this kernel. +template +__global__ void rms_norm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. 
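+
+   For reference, each enabled specialization below provides roughly:
+     static constexpr bool exists = true;
+     using hip_type        = ...;  // scalar device type, e.g. __half
+     using packed_hip_type = ...;  // packed device type, e.g. __half2
+     __device__ static float           convert(hip_type x);
+     __device__ static float2          convert(packed_hip_type x);
+     __device__ static hip_type        convert(float x);
+     __device__ static packed_hip_type convert(float2 x);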
+ */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { return __half2float(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. 
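+   For example, width = 8 with 2-byte FP16/BF16 elements gives a 16-byte
+   struct, so each load/store of one _f16Vec maps onto a single 128-bit
+   global memory transaction; this is the width selected by the launcher
+   below when hidden_size % 8 == 0 and all pointers are 16-byte aligned.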
+ */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, + "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; + input_v[id] = temp; + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +} // namespace vllm + +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); +} + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int 
hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h new file mode 100644 index 0000000..a251730 --- /dev/null +++ b/csrc/moe/moe_ops.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu new file mode 100644 index 0000000..de9747b --- /dev/null +++ b/csrc/moe/topk_softmax_kernels.cu @@ -0,0 +1,506 @@ +/* + * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "../cuda_compat.h" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { +namespace moe { + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N +> +class alignas(Alignment) AlignedArray { + float data[N]; +}; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. 
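+// Launch shape, as used by topkGatingSoftmaxKernelLauncher further below:
+// each thread block handles one row of the gating output (indexed by
+// blockIdx.x) with TPB threads per block (TPB = 256 in the generic
+// non-power-of-two fallback); rows whose `finished` flag is set are skipped.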
+template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. 
This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + int* source_rows, const int k, const int start_expert, const int end_expert) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. 
+ const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. 
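+    // Summary of the k-loop below: each iteration (1) does a thread-local
+    // argmax over this thread's VPT softmax values, (2) runs a butterfly
+    // shuffle so every thread in the row group agrees on the winning
+    // (value, expert) pair, (3) has the group leader write the result to
+    // global memory, and (4) blanks the winning entry to a large negative
+    // value so the next iteration selects the next-largest expert.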
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. 
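+// Worked example (illustrative): with EXPERTS = 8, BYTES_PER_LDG = 16 and
+// WARP_SIZE = 32, this yields ELTS_PER_LDG = 4, VPT = 4, THREADS_PER_ROW = 2
+// and ROWS_PER_WARP = 16, i.e. two threads cooperate on each row and each
+// warp covers 16 rows of the gating output.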
+template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indicies, \ + token_expert_indices, num_tokens, topk, 0, num_experts, \ + stream); + +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, + float* topk_weights, + int* topk_indicies, + int* token_expert_indices, + float* softmax_workspace, + const int num_tokens, + const int num_experts, + const int topk, + cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK(softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + moeSoftmax<<>>( + gating_output, nullptr, softmax_workspace, num_experts); + moeTopK<<>>( + softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, + num_experts, topk, 0, num_experts); + } + } +} + +} // namespace moe +} // namespace vllm + +void topk_softmax( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output) // [num_tokens, num_experts] +{ + const int num_experts = gating_output.size(-1); + const int num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? 
num_tokens * num_experts : 0; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp new file mode 100644 index 0000000..86e42af --- /dev/null +++ b/csrc/moe/torch_bindings.cpp @@ -0,0 +1,12 @@ +#include "core/registration.h" +#include "moe_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // Apply topk softmax to the gating outputs. + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu new file mode 100644 index 0000000..1f8d75d --- /dev/null +++ b/csrc/moe_align_block_size_kernels.cu @@ -0,0 +1,134 @@ +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace vllm { + +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} // namespace + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. 
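+  // Worked example (illustrative), matching the topk_ids example further
+  // below: for topk_ids = [0,1,2,1,2,3,0,3,4] and block_size = 4 the
+  // per-expert counts are {2, 2, 2, 2, 1}; rounding each up to a multiple
+  // of block_size gives {4, 4, 4, 4, 4}, so cumsum = [0, 4, 8, 12, 16, 20]
+  // and *total_tokens_post_pad = 20.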
+ if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + // set dynamic shared mem + auto kernel = vllm::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<<1, num_experts, shared_mem, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h new file mode 100644 index 0000000..6bf0cff --- /dev/null +++ b/csrc/ops.h @@ -0,0 +1,218 @@ +#pragma once + +#include +#include + +#include "core/scalar_type.hpp" + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& 
seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + double epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int64_t rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); + +void gelu_quick(torch::Tensor& out, torch::Tensor& input); + +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables); + +#ifndef USE_ROCM +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const std::vector& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +namespace machete { + +std::vector supported_schedules( + vllm::ScalarTypeTorchPtr const& btype); + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + vllm::ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule); + +torch::Tensor prepack_B(torch::Tensor const& B, + vllm::ScalarTypeTorchPtr const& btype); + +}; // namespace machete + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& b_zeros, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t 
size_k, + bool is_k_full, bool has_zp, + bool use_fp32_reduce); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, + int64_t size_n, int64_t num_bits); + +torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, + int64_t n); + +torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, + int64_t type, int64_t row); + +torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, + int64_t row); + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k); + +bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); + +void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias); + +torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, + torch::Tensor const& b_q_weight, + torch::Tensor const& s_tok, + torch::Tensor const& s_ch, + torch::Tensor const& s_group, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k); +#endif + +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); + +void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor& scales); + +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor& scale); + +void dynamic_per_token_scaled_fp8_quant( + torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, + c10::optional const& scale_ub); + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); + +#ifndef USE_ROCM +using fptr_t = int64_t; +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, + bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); +void dispose(fptr_t _fa); +int64_t meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const 
std::vector>& offsets); +#endif diff --git a/csrc/paged_attention/README.md b/csrc/paged_attention/README.md deleted file mode 100644 index 1b66514..0000000 --- a/csrc/paged_attention/README.md +++ /dev/null @@ -1,3 +0,0 @@ -Note: Current version of paged attention kernels adapted from https://github.com/vllm-project/vllm 0.2.7 - -For any changes from vLLM, please mark with `//`(start) and `//<\fms>`(end) and explain the changes in this README. \ No newline at end of file diff --git a/csrc/paged_attention/attention/attention_kernels.cu b/csrc/paged_attention/attention/attention_kernels.cu deleted file mode 100644 index c7659ca..0000000 --- a/csrc/paged_attention/attention/attention_kernels.cu +++ /dev/null @@ -1,882 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifdef USE_ROCM -#include -#endif - -#include -#include -#include - -#include "attention_dtypes.h" -#include "attention_utils.cuh" - -#include - -#ifndef USE_ROCM -#define WARP_SIZE 32 -#else -#define WARP_SIZE warpSize -#endif -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) - -namespace vllm { - -// Utility function for attention softmax. -template -inline __device__ float block_sum(float* red_smem, float sum) { - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < NUM_WARPS) { - sum = red_smem[lane]; - } - - // Parallel reduction inside the warp. -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Broadcast to other threads. - return VLLM_SHFL_SYNC(sum, 0); -} - -// TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> // Zero means no partitioning. 
-__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - const int num_tokens = end_token_idx - start_token_idx; - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS - assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int num_queries_per_kv = num_heads / num_kv_heads; - const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. 
- constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - K_vec k_vecs[NUM_VECS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - - // Compute dot product. 
- // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = VLLM_SHFL_SYNC(qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *exp_sums_ptr = exp_sum; - } - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename Vec::Type; - using Float_L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); - - // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - scalar_t zero_value; - zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). 
- const int64_t physical_block_number = static_cast(block_table[block_idx]); - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); - - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_context_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); -#pragma unroll - for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; - } - } - accs[i] += dot(logits_vec, v_vec); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += VLLM_SHFL_XOR_SYNC(acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - -// Grid: (num_heads, num_seqs, 1). 
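For reference, the kernel above addresses the key cache as [num_blocks, num_kv_heads, head_size/x, block_size, x] with x = 16 / sizeof(scalar_t), and the value cache as [num_blocks, num_kv_heads, head_size, block_size]. The following host-side sketch is illustrative only (the helper names are not part of this patch); it spells out the flat-index arithmetic implied by those layouts, matching the offset1/offset2 computation in the key loop and the row/offset computation in the value loop.

#include <cstdint>
#include <cstdio>

// Flat offset of element (block, kv_head, head_dim, token) in a key cache
// laid out as [num_blocks, num_kv_heads, head_size/x, block_size, x].
static int64_t key_cache_offset(int64_t block, int kv_head, int head_dim, int token,
                                int num_kv_heads, int head_size, int block_size, int x) {
  const int x_idx = head_dim / x;  // which x-wide chunk of the head dimension
  const int x_off = head_dim % x;  // position inside that chunk
  return (((block * num_kv_heads + kv_head) * (head_size / x) + x_idx) * block_size + token) * x + x_off;
}

// Flat offset in a value cache laid out as [num_blocks, num_kv_heads, head_size, block_size].
static int64_t value_cache_offset(int64_t block, int kv_head, int head_dim, int token,
                                  int num_kv_heads, int head_size, int block_size) {
  return ((block * num_kv_heads + kv_head) * head_size + head_dim) * block_size + token;
}

int main() {
  // Example: fp16 cache (x = 16 / 2 = 8), 8 KV heads, head_size 128, block_size 16.
  printf("%lld\n", (long long)key_cache_offset(3, 1, 70, 5, 8, 128, 16, 8));
  printf("%lld\n", (long long)value_cache_offset(3, 1, 70, 5, 8, 128, 16));
  return 0;
}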
-template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. 
- scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = VLLM_SHFL_SYNC(max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } -} - -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); - -// TODO(woosuk): Tune NUM_THREADS. 
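The v2 reduce kernel above merges per-partition results with a numerically stable rescaling: take the global maximum of the stored max logits, rescale each partition's exp-sum by exp(local_max - global_max), and blend the partial outputs with those rescaled weights normalized by the global exp-sum. Below is a minimal host-side sketch of the same arithmetic for a single (sequence, head) pair; the function and variable names are illustrative and not part of this patch.

#include <cmath>
#include <vector>

// partial_out[p] holds HEAD_SIZE floats produced by partition p;
// max_logits[p] and exp_sums[p] are the per-partition statistics
// written by the main v2 kernel.
static std::vector<float> combine_partitions(
    const std::vector<std::vector<float>>& partial_out,
    const std::vector<float>& max_logits,
    const std::vector<float>& exp_sums) {
  const size_t num_partitions = partial_out.size();
  const size_t head_size = partial_out[0].size();

  float global_max = -INFINITY;
  for (float m : max_logits) global_max = std::fmax(global_max, m);

  // Rescale every partition's exp-sum into the global max's frame.
  std::vector<float> weight(num_partitions);
  float global_exp_sum = 0.f;
  for (size_t p = 0; p < num_partitions; ++p) {
    weight[p] = exp_sums[p] * std::exp(max_logits[p] - global_max);
    global_exp_sum += weight[p];
  }

  // Weighted average of the partial outputs.
  std::vector<float> out(head_size, 0.f);
  const float inv = 1.f / (global_exp_sum + 1e-6f);
  for (size_t p = 0; p < num_partitions; ++p)
    for (size_t i = 0; i < head_size; ++i)
      out[i] += partial_out[p][i] * weight[p] * inv;
  return out;
}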
-template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128> -void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len - // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); - - dim3 grid(num_heads, num_seqs, 1); - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
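The launcher above sizes dynamic shared memory as the larger of the softmax-logits buffer (one float per token, padded up to a multiple of BLOCK_SIZE) and the cross-warp output buffer ((NUM_WARPS / 2) * head_size floats). A small stand-alone sketch of that computation follows, with illustrative numbers that are not taken from the patch.

#include <algorithm>
#include <cstdio>

static constexpr int divide_round_up(int a, int b) { return (a + b - 1) / b; }

static int v1_shared_mem_bytes(int max_context_len, int block_size, int head_size,
                               int num_threads = 128, int warp_size = 32) {
  const int num_warps = num_threads / warp_size;
  const int padded_max_context_len =
      divide_round_up(max_context_len, block_size) * block_size;
  const int logits_size = padded_max_context_len * (int)sizeof(float);   // softmax logits
  const int outputs_size = (num_warps / 2) * head_size * (int)sizeof(float);  // cross-warp outputs
  return std::max(logits_size, outputs_size);
}

int main() {
  // Example: context 4096, block 16, head 128 -> the logits buffer dominates (16 KiB).
  printf("%d\n", v1_shared_mem_bytes(4096, 16, 128));
  return 0;
}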
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - context_lens_ptr, \ - max_num_partitions); - -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> -void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? 
- reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. - dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. - dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - - dim3 block(NUM_THREADS); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V2(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V2(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V2(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V2(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V2(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V2(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
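In the v2 path above, each sequence is split into PARTITION_SIZE-token partitions (512 by default): the main kernel launches on a (num_heads, num_seqs, max_num_partitions) grid, while the reduce kernel needs 2 * max_num_partitions floats of dynamic shared memory to hold the max logits and the rescaled exp-sums. A short illustrative computation follows; the numbers are examples, not values from the patch.

#include <cstdio>

static constexpr int divide_round_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int partition_size = 512;    // default PARTITION_SIZE
  const int max_context_len = 8192;  // example sequence length
  const int max_num_partitions = divide_round_up(max_context_len, partition_size);
  const int reduce_shared_mem_bytes = 2 * max_num_partitions * (int)sizeof(float);
  // 16 partitions -> 128 bytes of shared memory for the reduce kernel.
  printf("partitions=%d, reduce smem=%d bytes\n", max_num_partitions, reduce_shared_mem_bytes);
  return 0;
}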
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } -} - -#undef WARP_SIZE -#undef MAX -#undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/paged_attention/cache.h b/csrc/paged_attention/cache.h deleted file mode 100644 index 05508e9..0000000 --- a/csrc/paged_attention/cache.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include - -#include -#include - -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const std::map>& block_mapping); - -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping); diff --git a/csrc/paged_attention/cache_kernels.cu b/csrc/paged_attention/cache_kernels.cu deleted file mode 100644 index b4d2cf1..0000000 --- a/csrc/paged_attention/cache_kernels.cu +++ /dev/null @@ -1,190 +0,0 @@ -#include -#include -#include - -#include "cuda_compat.h" -#include "dispatch_utils.h" - -#include -#include -#include -#include - -namespace vllm { - -// Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - const int layer_idx = blockIdx.x; - const int pair_idx = blockIdx.y; - - scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); - int64_t src_block_number = block_mapping[2 * pair_idx]; - int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; - - const int64_t src_block_offset = src_block_number * numel_per_block; - const int64_t dst_block_offset = dst_block_number * numel_per_block; - for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { - int64_t src_offset = src_block_offset + i; - int64_t dst_offset = dst_block_offset + i; - key_cache[dst_offset] = key_cache[src_offset]; - } - for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { - int64_t src_offset = src_block_offset + i; - int64_t dst_offset = dst_block_offset + i; - value_cache[dst_offset] = value_cache[src_offset]; - } 
-} - -} // namespace vllm - -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const std::map>& block_mapping) { - int num_layers = key_caches.size(); - TORCH_CHECK(num_layers == value_caches.size()); - if (num_layers == 0) { - return; - } - torch::Device cache_device = key_caches[0].device(); - TORCH_CHECK(cache_device.is_cuda()); - - // Create data structures for the kernel. - // Create an array of pointers to the key and value caches. - int64_t key_cache_ptrs[num_layers]; - int64_t value_cache_ptrs[num_layers]; - for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); - } - // Create block mapping array. - std::vector block_mapping_vec; - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - for (int64_t dst_block_number : pair.second) { - block_mapping_vec.push_back(src_block_number); - block_mapping_vec.push_back(dst_block_number); - } - } - int64_t* block_mapping_array = block_mapping_vec.data(); - int num_pairs = block_mapping_vec.size() / 2; - - // Move the data structures to the GPU. - // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor block_mapping_tensor = torch::from_blob( - block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); - - // Launch the kernel. - const int numel_per_block = key_caches[0][0].numel(); - dim3 grid(num_layers, num_pairs); - dim3 block(std::min(1024, numel_per_block)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping_tensor.data_ptr(), - numel_per_block); - })); -} - -namespace vllm { - -template -__global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x) { - const int64_t token_idx = blockIdx.x; - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx < 0) { - // Padding token that should be ignored. 
- return; - } - - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - const int n = num_heads * head_size; - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int64_t src_key_idx = token_idx * key_stride + i; - const int64_t src_value_idx = token_idx * value_stride + i; - - const int head_idx = i / head_size; - const int head_offset = i % head_size; - const int x_idx = head_offset / x; - const int x_offset = head_offset % x; - - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; - key_cache[tgt_key_idx] = key[src_key_idx]; - value_cache[tgt_value_idx] = value[src_value_idx]; - } -} - -} // namespace vllm - -void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping) // [num_tokens] -{ - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_kernel", - [&] { - vllm::reshape_and_cache_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - }); -} \ No newline at end of file diff --git a/csrc/paged_attention/cuda_utils.h b/csrc/paged_attention/cuda_utils.h deleted file mode 100644 index 6d87231..0000000 --- a/csrc/paged_attention/cuda_utils.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -#include - -int get_device_attribute( - int attribute, - int device_id); \ No newline at end of file diff --git a/csrc/paged_attention/cuda_utils_kernels.cu b/csrc/paged_attention/cuda_utils_kernels.cu deleted file mode 100644 index fc27f97..0000000 --- a/csrc/paged_attention/cuda_utils_kernels.cu +++ /dev/null @@ -1,17 +0,0 @@ -#ifdef USE_ROCM - #include -#endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; -} \ No newline at end of file diff --git a/csrc/paged_attention/dispatch_utils.h b/csrc/paged_attention/dispatch_utils.h deleted file mode 100644 index b8dc9e8..0000000 --- a/csrc/paged_attention/dispatch_utils.h +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Adapted from - * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h - */ -#pragma once - -#include - -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
\ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) - -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) \ No newline at end of file diff --git a/csrc/paged_attention/ops.h b/csrc/paged_attention/ops.h deleted file mode 100644 index 21fe203..0000000 --- a/csrc/paged_attention/ops.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include - -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); - -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); diff --git a/csrc/paged_attention/pybind.cpp b/csrc/paged_attention/pybind.cpp deleted file mode 100644 index 0b9d558..0000000 --- a/csrc/paged_attention/pybind.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("attn_ops", "vLLM attn operators"); - - // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - - // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); -} diff --git a/csrc/paged_attention/reduction_utils.cuh b/csrc/paged_attention/reduction_utils.cuh deleted file mode 100644 index bc35aa0..0000000 --- a/csrc/paged_attention/reduction_utils.cuh +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -namespace vllm { - -template -__inline__ __device__ T warpReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); - val = warpReduceSum(val); - return val; -} - -} // namespace vllm diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu new file mode 100644 index 0000000..97184a8 --- /dev/null +++ b/csrc/pos_encoding_kernels.cu @@ -0,0 +1,203 @@ +#include +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +namespace vllm { + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or 
[num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); +} + +template +__global__ void batched_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); +} + +} // namespace vllm + +void rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { + int64_t num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); +} + +/* +Batched version of rotary embedding, pack multiple LoRAs together +and process in batched manner. 
+*/ +void batched_rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int64_t rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] +) { + int64_t num_tokens = cos_sin_cache_offsets.size(0); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); +} diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu new file mode 100644 index 0000000..0e537dd --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cu @@ -0,0 +1,131 @@ +/* + * The goal of this GPU kernel is to advance input tensors on the GPU directly + * PR: https://github.com/vllm-project/vllm/pull/6338 + * Current restrictions: + * 1. Specialized for DraftModelRunner + * 2. 
Supports flash_attn only + */ + +#include "advance_step.cuh" + +namespace prepare_inputs { + +// +template +__global__ void advance_step_kernel(int num_seqs, int num_queries, + int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, + long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, + int64_t const block_tables_stride) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x >= num_query_blocks) { + return; + } + + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id >= num_queries) { + return; + } + + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; +} + +inline void verify_tensor(std::string const& name, torch::Tensor& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) == size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) == size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + +void advance_step(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int + + if (logging) { + printf("advance_step:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + advance_step_kernel<<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), 
+ reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +} // namespace prepare_inputs + +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables) { + prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, + sampled_token_ids, input_positions, seq_lens, + slot_mapping, block_tables); +} \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh new file mode 100644 index 0000000..f215746 --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace prepare_inputs { + +static constexpr int max_threads = 256; +static constexpr bool logging = false; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +} // namespace prepare_inputs diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu new file mode 100644 index 0000000..79cd2c6 --- /dev/null +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -0,0 +1,597 @@ +/* + * Modified by Neural Magic + * Adapted from https://github.com/Vahe1994/AQLM + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace vllm { +namespace aqlm { + +__global__ void Code1x16MatVec( + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. 
+ auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + __shared__ int4 sh_b[32 * 9]; + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[4]; + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code2x8MatVec( + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. + +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. 
+ auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + extern __shared__ int4 sh[]; + int4* sh_b = sh; + int4* sh_code = sh_b + 32 * 9; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +#pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code1x16Dequant( + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + auto dec = reinterpret_cast(&chunk); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. 
+ asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); + + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +__global__ void Code2x8Dequant( + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + extern __shared__ int4 sh[]; + int4* sh_code = sh; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } + +const int THREAD_M = 16; + +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); +} + +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); 
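// NOTE: the cudaFuncSetAttribute call above is needed because this kernel
// requests 16 * (2 * 256 * 8 + 32 * 9) = 70144 bytes of dynamic shared memory
// (two 256-entry codebooks replicated 8 ways as int4, plus the padded 32x9
// int4 tile of B), which exceeds the default 48 KiB per-block limit.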
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code2x8MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); +} + +void code1x16_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. + ); +} + +// Dequantizes the code and codebook into weights. +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + Code2x8Dequant<<>>( + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); +} + +int codebook_stride(const torch::Tensor& codebooks) { + return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); +} + +void code1x16_matvec( + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. 
+) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); +} + +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); +} + +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +// Accumulate the partition sizes. +int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; + size_t i = 0; + int last = 0; + assert(codebook_partition_sizes.size() <= 4); + for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { + *cumulative_size = codebook_partition_sizes[i] + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. 
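+  // "Unreachable" here means strictly greater than any valid row index: the
+  // kernels advance through codebooks with
+  //   while (a_gl_rd >= *codebook_size) { codebook += codebook_stride; ++codebook_size; }
+  // and a_gl_rd is always < prob_m <= the last real cumulative size, so the
+  // padded entries (10x the running total) are never selected.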
+ for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; + } + return cumulative_sizes; +} + +} // namespace aqlm +} // namespace vllm + +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const std::vector& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); + int const entries = codebooks.size(1); + + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); + } + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") + return {}; +} + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); + int const entries = codebooks.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); + int rows = codes.size(1); + int cols = codes.size(0); + + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + + assert(out_features == std::accumulate(codebook_partition_sizes.begin(), + codebook_partition_sizes.end(), 0)); + + auto weights = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) 
weights *= + // scales.index({"...", 0, 0}); + + return weights; + } + + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") + return {}; +} diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh new file mode 100644 index 0000000..5fa4b5f --- /dev/null +++ b/csrc/quantization/awq/dequantize.cuh @@ -0,0 +1,102 @@ +/* +Adapted from https://github.com/mit-han-lab/llm-awq +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} +} +*/ + +#pragma once + +namespace vllm { +namespace awq { + +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + uint4 result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. 
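+  // The magic numbers below rely on the usual fp16 bit trick: or-ing a 4-bit
+  // value v into the mantissa of a half with exponent pattern 0x6400 yields
+  // the value 1024 + v, so
+  //   low nibbles  (elt_01/45): h = 1024 + v,    v = h - 1024          (sub with 0x6400)
+  //   high nibbles (elt_23/67): h = 1024 + 16*v, v = h * (1/16) - 64   (fma with 0x2c00 and 0xd400)
+  // e.g. nibble 0x5 packs to half bits 0x6405 = 1029.0, and 1029 - 1024 = 5.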
+ + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + + return result; +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu new file mode 100644 index 0000000..9da724a --- /dev/null +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -0,0 +1,526 @@ +/* +Adapted from https://github.com/mit-han-lab/llm-awq +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} +} + */ + +#include +#include + +#include "dequantize.cuh" + +#include + +namespace vllm { +namespace awq { + +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (N + 8)]; + + int j_factors1 = ((OC + N - 1) / N); + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[N / 4]; + for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / N; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + 
(((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } else { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. 
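+      // Per half2 lane, the sub/fma chain below computes the AWQ affine
+      // dequantization
+      //   w = (q - z) * s
+      // as a sub.f16x2 followed by an fma.rn.f16x2 with addend 0, applied to
+      // each of the four half2 pairs packed in the uint4.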
+ asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + #else + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + 
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + #endif + } + } + } + + // TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +#endif +} + +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { + static constexpr uint32_t ZERO = 0x0; + half B_shared[32 * (128 + 8)]; + + half* B_shared_ptr2 = B_shared; + + int N = blockDim.x * gridDim.x; // 2 + int col = (blockIdx.x * blockDim.x + threadIdx.x); + int row = blockIdx.y * blockDim.y + threadIdx.y; + int index1 = 8 * col + 8 * row * N; + half* C_ptr2 = C + index1; + + int index2 = col + row * N; + int* B_ptr2 = B + index2; + + int index3 = col + (int)(row / G) * N; + int* zeros_ptr2 = zeros + index3; + int index4 = 8 * col + (int)(row / G) * N * 8; + half* scaling_factors_ptr2 = scaling_factors + index4; + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); + + uint32_t B_loaded = *(uint32_t*)B_ptr2; + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + *(uint4*)B_shared_ptr2 = B_loaded_fp16; + + for (int i = 0; i < 8; ++i) { + *(C_ptr2 + i) = B_shared[i]; + } +} + +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } + + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + 
.device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); + + return _de_kernel; +} + +// in_feats: M, IC [float16] +// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] +// scaling_factors: IC // G, OC [float16] +// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] +// assume that batch_size < 16 for now + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); +} diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000..616fc14 --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,125 @@ +#include +#include 
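+// The kernels in this file implement two activation-quantization paths:
+//  * static_scaled_int8_quant:  a single per-tensor scale is given and each
+//    element is quantized as  q = round_sat_i8(x / scale).
+//  * dynamic_scaled_int8_quant: one scale per token is computed on the fly as
+//    max|x| / 127 over the hidden dimension, then applied the same way.
+// A scalar reference of the rounding helper on the CUDA path (illustrative
+// only; the real code uses the cvt.rni.sat.s8.f32 instruction):
+//   int8_t q = (int8_t)std::clamp((int)std::nearbyint(x), -128, 127);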
+#include + +#include "../../dispatch_utils.h" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) / scale); + } +} + +template +__global__ void dynamic_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + float absmax_val = 0.0f; + float const zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = static_cast(input[token_idx * hidden_size + i]); + val = val > zero ? val : -val; + absmax_val = val > absmax_val ? val : absmax_val; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; + } + __syncthreads(); + + float const tmp_scale = 127.0f / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + } +} + +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), hidden_size); + }); +} + +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor& scales) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), 
"dynamic_scaled_int8_quant_kernel", [&] { + vllm::dynamic_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), + scales.data_ptr(), hidden_size); + }); +} diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md new file mode 100644 index 0000000..aae0415 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -0,0 +1,147 @@ +# CUTLASS Epilogues + +## Introduction +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: +1. ScaledEpilogue: symmetric quantization for activations, no bias. +1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. +1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. +1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. +Instead, if no bias is passed, the epilogue will use 0 as the bias. +That induces a redundant addition operation (and runtime check), but the performance impact is minor. + +## Underlying Linear Algebra + +More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). + +If $` \widehat X `$ is the quantized $` X `$, our matrices become the following + +```math +A = s_a (\widehat A - J_a z_a) +``` +```math +B = s_b \widehat B +``` +```math +D = A B + C +``` +```math +D = s_a s_b \widehat D + C +``` + +Here, D is the output of the GEMM, and C is the bias. +A is the activations and supports asymmetric quantization, +and B is the weights and only supports symmetric quantization. +$ s_a $ and $s_b$ are the scales for activations and weights, respectively. +$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A. +Additional epilogues would be required to support asymmetric quantization for weights. + +Expanding further, we can calculate $` \widehat D `$ as follows: + +```math +A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B +``` +```math +A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) +``` +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` + +Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, +and $` J_a \widehat B `$ is known ahead of time. +Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. + +## Epilogues + +### ScaledEpilogue +This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D +``` +```math +D = s_a s_b \widehat A \widehat B +``` + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). + +### ScaledEpilogueBias +This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. 
+The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \widehat A \widehat B + C +``` + + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +### ScaledEpilogueAzp +This epilogue computes the asymmetric per-tensor quantization for activations with bias. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C +``` + +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +That is precomputed and stored in `azp_with_adj` as a row-vector. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-tensor as the zero-points are per-tensor. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. + +### ScaledEpilogueAzpPerToken +This epilogue computes the asymmetric per-token quantization for activations with bias. + +The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. +That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-token as the zero-points are per-token. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). +- `azp` is the zero-point (`z_a`), is per-token (column-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. + +The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): +``` +out = scale_a * scale_b * (Dq - azp_adj * azp) + bias +``` diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 0000000..d407d66 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,496 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. 
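+  // Illustrative host-side use (hypothetical variable names): the pointer is a
+  // device pointer in both modes and only the flag changes, e.g.
+  //   Arguments{scale_ptr, /*row_broadcast=*/scale_numel != 1, /*dRow=*/{}};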
+ struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null 
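+// (It appears to back the optional-bias path described in Epilogues.md: a null
+// pointer simply broadcasts zero, so the bias-free case reuses the same
+// compiled epilogue instead of requiring a separate one.)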
+template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + 
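+// Note on usage: in the w8a8 epilogues the row visitors above carry quantities
+// indexed by output column (per-channel weight scales, bias), while the column
+// visitor below carries quantities indexed by output row (per-token activation
+// scales), matching the parameter shapes listed in Epilogues.md.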
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 0000000..58b1e8f --- /dev/null 
+++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. 
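+// Relative to the c2x visitors, these SM90 versions target the TMA
+// warp-specialized epilogue: Sm90RowOrScalarBroadcast below stages the row
+// through shared memory before copying it into registers, while
+// Sm90ColOrScalarBroadcast loads the column (or the scalar) directly from
+// global memory, since the accumulator layout already distributes the column's
+// elements evenly among threads.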
+// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void 
+    begin() {
+      if (!params.row_broadcast) {
+        fill(tSR_rRow, *(params.ptr_row));
+        return;
+      }
+
+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM,
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
+      if (epi_m == 0) { // Assumes M-major subtile loop
+        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
+      }
+    }
+
+    template <typename ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE Array<Element, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<Element, FragmentSize> frg_row;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < FragmentSize; ++i) {
+        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
+      }
+
+      return frg_row;
+    }
+  };
+
+  template <
+    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+    class... Args
+  >
+  CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));
+
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));         // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
+    //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom{},
+                                     Layout< Shape<_1, ThreadCount>,
+                                             Stride<_0, _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    //// G2S: Coord
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));                                         // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Column vector broadcast
+template<
+  int Stages,
+  class CtaTileShapeMNK,
+  class Element,
+  class StrideMNL = Stride<_1,_0,_0>,
+  int Alignment = 128 / sizeof_bits_v<Element>
+>
+struct Sm90ColOrScalarBroadcast {
+  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
+  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
+  static_assert(
+    (cute::is_same_v<StrideMNL, Stride<_1,_0,_0>>) ||  // col vector broadcast, e.g. per-row alpha/bias
+    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias
+
+  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
+  struct SharedStorage { };
+
+  // This struct has been modified to have a bool indicating that ptr_col is a
+  // scalar that must be broadcast, instead of containing a scalar that is
+  // valid if ptr_col is null.
+  struct Arguments {
+    Element const* ptr_col = nullptr;
+    bool col_broadcast = true;
+    StrideMNL dCol = {};
+  };
+
+  using Params = Arguments;
+
+  template <class ProblemShape>
+  static constexpr Params
+  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+    return args;
+  }
+
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
+  template <class ProblemShape>
+  static size_t
+  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+    return 0;
+  }
+
+  template <class ProblemShape>
+  static cutlass::Status
+  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+      CudaHostAdapter* cuda_adapter = nullptr) {
+    return cutlass::Status::kSuccess;
+  }
+
+  CUTLASS_DEVICE bool
+  is_producer_load_needed() const {
+    return false;
+  }
+
+  CUTLASS_DEVICE bool
+  is_C_load_needed() const {
+    return false;
+  }
+
+  CUTLASS_DEVICE bool
+  is_zero() const {
+    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
+  }
+
+  CUTLASS_HOST_DEVICE
+  Sm90ColOrScalarBroadcast() { }
+
+  CUTLASS_HOST_DEVICE
+  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
+      : params(params) { }
+
+  Params params;
+
+  template <class... Args>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
+    return EmptyProducerLoadCallbacks{};
+  }
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+    CUTLASS_DEVICE
+    ConsumerStoreCallbacks(
+        GTensor&& tCgCol,
+        RTensor&& tCrCol,
+        CTensor&& tCcCol,
+        ProblemShape problem_shape,
+        Params const& params
+    ):
+      tCgCol(cute::forward<GTensor>(tCgCol)),
+      tCrCol(cute::forward<RTensor>(tCrCol)),
+      tCcCol(cute::forward<CTensor>(tCcCol)),
+      m(get<0>(problem_shape)),
+      params(params) {}
+
+    GTensor tCgCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    RTensor tCrCol;
+    CTensor tCcCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    Params const& params;
+    int m;
+
+    CUTLASS_DEVICE void
+    begin() {
+      Tensor pred = make_tensor<bool>(shape(tCgCol));
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < size(pred); ++i) {
+        pred(i) = get<0>(tCcCol(i)) < m;
+      }
+
+      if (!params.col_broadcast) {
+        fill(tCrCol, *(params.ptr_col));
+        return;
+      }
+
+      // Filter so we don't issue redundant copies over stride-0 modes
+      // (only works if 0-strides are in same location, which is by construction)
+      copy_if(pred, filter(tCgCol), filter(tCrCol));
+    }
+
+    template <typename ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE Array<Element, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<Element, FragmentSize> frg_col;
+      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < FragmentSize; ++i) {
+        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
+      }
+
+      return frg_col;
+    }
+
+  };
+
+  template <
+    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+    class... Args
+  >
+  CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
+    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tCrCol = make_tensor_like(tCgCol);                   // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    // Generate an identity tensor matching the shape of the global tensor and
+    // partition the same way, this will be used to generate the predicate
+    // tensor for loading
+    Tensor cCol = make_identity_tensor(mCol.shape());
+    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+
+    return ConsumerStoreCallbacks(
+      cute::move(tCgCol),
+      cute::move(tCrCol),
+      cute::move(tCcCol),
+      args.problem_shape_mnkl,
+      params
+    );
+  }
+};
+
+}
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
new file mode 100644
index 0000000..bf04bb4
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/common.hpp
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include <climits>
+
+/**
+ * Helper function for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status)                        \
+  {                                                  \
+    TORCH_CHECK(status == cutlass::Status::kSuccess, \
+                cutlassGetStatusString(status))      \
+  }
+
+inline uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+  int max_shared_mem_per_block_opt_in = 0;
+  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                         device);
+  return max_shared_mem_per_block_opt_in;
+}
+
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
new file mode 100644
index 0000000..ee801e1
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -0,0 +1,198 @@
+#include
+#include
+#include "cutlass/cutlass.h"
+
+#include "scaled_mm_c2x.cuh"
+#include "scaled_mm_c2x_sm75_dispatch.cuh"
+#include "scaled_mm_c2x_sm80_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
+
+/*
+   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
+   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
+*/
+
+template