Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2534,8 +2534,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;

static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
if (hsv >= 192) {
static uint32_t get_fa_scalar_num_large_rows(const vk_device& device, uint32_t hsk, uint32_t hsv) {
if ((device->vendor_id == VK_VENDOR_ID_INTEL) &&
(device->driver_id == vk::DriverId::eIntelProprietaryWindows)) {
return 2;
} else if (hsv >= 192) {
return 2;
} else if ((hsv | hsk) & 8) {
return 4;
Expand All @@ -2559,7 +2562,7 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
}
}

static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
static std::array<uint32_t, 2> fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp);
GGML_UNUSED(hsv);

Expand All @@ -2570,9 +2573,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
return {get_fa_scalar_num_large_rows(device, hsk, hsv), 64};
} else {
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
return {get_fa_scalar_num_large_rows(device, hsk, hsv), 32};
}
}
}
Expand Down Expand Up @@ -2601,8 +2604,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64};
}

static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
return fa_rows_cols(device, path, hsk, hsv, 0, type, small_rows)[1];
}

static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
Expand Down Expand Up @@ -2945,7 +2948,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
};

auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
return {fa_rows_cols(device, path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};

auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
Expand All @@ -2958,7 +2961,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, small_rows);

// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
Expand All @@ -2979,15 +2982,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(device, FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(device, FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \
Expand Down Expand Up @@ -7958,7 +7961,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Br = get_fa_scalar_num_large_rows(device, hsk, hsv);
const uint32_t Bc = scalar_flash_attention_Bc;

const uint32_t tmpsh = wg_size * sizeof(float);
Expand Down Expand Up @@ -8089,7 +8092,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
max_gqa = get_fa_scalar_num_large_rows(ctx->device, HSK, HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
Expand Down Expand Up @@ -8139,7 +8142,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride /= 4;
}

uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
Expand Down