Commit 21e0267

vulkan/cuda: fix topk_moe with exp_probs_b
I updated test_topk_moe to more closely match llm_graph_context::build_moe_ffn and added coverage for exp_probs_b and some other missing combinations. This exposed a bug in both the CUDA and Vulkan backends, which assumed the input to argsort and the input to get_rows are the same tensor. I'd like to optimize this graph in another change; for now this just makes it functional. CUDA also had a bug where it read n_expert from the wrong place, leading to GGML_ASSERT failures in some of the new tests.
1 parent: 5c8a717
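For context, a minimal sketch of the graph shape that breaks the old assumption, written with the ggml API and mirroring the updated test_topk_moe (ctx, logits, ne, n_expert, n_expert_used, and n_tokens are assumed to be set up as in the test): build_moe_ffn can bias only the selection input with exp_probs_b, so the tensor fed to argsort is not the tensor fed to get_rows.

// probs gates the routing weights; selection_probs (probs + exp_probs_b) gates the expert choice
ggml_tensor * probs           = ggml_sigmoid(ctx, logits);
ggml_tensor * exp_probs_b     = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * selection_probs = ggml_add(ctx, probs, exp_probs_b);

// top-k selection runs on the biased tensor ...
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used);

// ... but the weights are gathered from the unbiased probs, so the argsort
// input and the get_rows input are different tensors
ggml_tensor * weights = ggml_get_rows(ctx,
        ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);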

File tree

5 files changed: +117, −35 lines


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 3 deletions
@@ -3076,16 +3076,23 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }

     if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 17 additions & 2 deletions
@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

-    const int n_expert = softmax->ne[0];
     // n_expert must be a power of 2
     if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
         return false;

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 6 additions & 1 deletion
@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool delayed_softmax = false,
                            ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 19 additions & 0 deletions
@@ -12889,24 +12889,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc

     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;

     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }

+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     const float * op_params = (const float *)softmax->op_params;

     float scale = op_params[0];

tests/test-backend-ops.cpp

Lines changed: 62 additions & 29 deletions
@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
     }
 };

+enum MoeGatingFunc {
+    GATING_FUNC_SOFTMAX,
+    GATING_FUNC_SIGMOID,
+    GATING_FUNC_SOFTMAX_WEIGHT,
+};
+
 struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    const bool delayed_softmax;
+    const bool bias_probs;
+    const MoeGatingFunc gating_func;
+    const float scale_w;

     test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
                   int n_expert_used = 1,
                   bool with_norm = false,
-                  bool delayed_softmax = false) :
+                  bool bias_probs = false,
+                  MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
+                  float scale_w = 0.0f) :
         ne(ne),
         n_expert_used(n_expert_used),
         with_norm(with_norm),
-        delayed_softmax(delayed_softmax) {
+        bias_probs(bias_probs),
+        gating_func(gating_func),
+        scale_w(scale_w) {
         GGML_ASSERT(n_expert_used <= ne[0]);
-        GGML_ASSERT(!(with_norm && delayed_softmax));
     }

-    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
+    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }

     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
         const int n_tokens = ne[1];

         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * probs =
+            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
+            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
+        ggml_set_name(probs, "probs");
+
+        ggml_tensor * selection_probs = probs;
+        if (bias_probs) {
+            ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_set_name(exp_probs_b, "exp_probs_b");
+            selection_probs = ggml_add(ctx, probs, exp_probs_b);
+            ggml_set_name(selection_probs, "selection_probs");
+        }

-        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_set_name(selected_experts, "selected_experts");

-        if (delayed_softmax) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_set_name(weights, "weights");
+
+        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }

         if (with_norm) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+            ggml_set_name(weights_sum, "weights_sum");

             weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
-            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+            weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }

-        ggml_set_name(out, "out");
-        return out;
+        if (scale_w) {
+            weights = ggml_scale(ctx, weights, scale_w);
+        }
+
+        ggml_set_name(weights, "weights");
+        return weights;
     }
 };

@@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

-    for (bool with_norm : {false, true}) {
-        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
-        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
-        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
+    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
+        for (bool with_norm : {false, true}) {
+            for (bool bias_probs : {false, true}) {
+                for (float scale_w : {0.0f, 2.0f}) {
+                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                }
+            }
+        }
     }

-    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
-    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
-
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
