@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
51185118 }
51195119};
51205120
5121+ enum MoeGatingFunc {
5122+ GATING_FUNC_SOFTMAX,
5123+ GATING_FUNC_SIGMOID,
5124+ GATING_FUNC_SOFTMAX_WEIGHT,
5125+ };
5126+
51215127struct test_topk_moe : public test_case {
51225128 const std::array<int64_t , 4 > ne;
51235129 const int n_expert_used;
51245130 const bool with_norm;
5125- const bool delayed_softmax;
5131+ const bool bias_probs;
5132+ const MoeGatingFunc gating_func;
5133+ const float scale_w;
51265134
51275135 test_topk_moe (std::array<int64_t , 4 > ne = { 10 , 5 , 1 , 1 },
51285136 int n_expert_used = 1 ,
51295137 bool with_norm = false ,
5130- bool delayed_softmax = false ) :
5138+ bool bias_probs = false ,
5139+ MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
5140+ float scale_w = 0.0f) :
51315141 ne (ne),
51325142 n_expert_used (n_expert_used),
51335143 with_norm (with_norm),
5134- delayed_softmax (delayed_softmax) {
5144+ bias_probs (bias_probs),
5145+ gating_func (gating_func),
5146+ scale_w (scale_w) {
51355147 GGML_ASSERT (n_expert_used <= ne[0 ]);
5136- GGML_ASSERT (!(with_norm && delayed_softmax));
51375148 }
51385149
5139- std::string vars () override { return VARS_TO_STR4 (ne, n_expert_used, with_norm, delayed_softmax ); }
5150+ std::string vars () override { return VARS_TO_STR6 (ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w ); }
51405151
51415152 std::string op_desc (ggml_tensor * t) override {
51425153 GGML_UNUSED (t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
51505161 const int n_tokens = ne[1 ];
51515162
51525163 ggml_tensor * logits = ggml_new_tensor (ctx, GGML_TYPE_F32, 4 , ne.data ());
5153- ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max (ctx, logits);
5154- ggml_tensor * selected_experts = ggml_argsort_top_k (ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
5164+ ggml_tensor * probs =
5165+ (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max (ctx, logits) :
5166+ (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid (ctx, logits) : logits;
5167+ ggml_set_name(probs, "probs");
5168+
5169+ ggml_tensor * selection_probs = probs;
5170+ if (bias_probs) {
5171+ ggml_tensor * exp_probs_b = ggml_new_tensor (ctx, GGML_TYPE_F32, 4 , ne.data ());
5172+ ggml_set_name(exp_probs_b, "exp_probs_b");
5173+ selection_probs = ggml_add (ctx, probs, exp_probs_b);
5174+ ggml_set_name(selection_probs, "selection_probs");
5175+ }
51555176
5156- ggml_tensor * out = ggml_get_rows (ctx, ggml_reshape_3d (ctx, probs, 1 , n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
5177+ ggml_tensor * selected_experts = ggml_argsort_top_k (ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
5178+ ggml_set_name(selected_experts, "selected_experts");
51575179
5158- if (delayed_softmax) {
5159- out = ggml_reshape_2d (ctx, out, n_expert_used, n_tokens);
5160- out = ggml_soft_max (ctx, out); // [n_expert_used, n_tokens]
5161- out = ggml_reshape_3d (ctx, out, 1 , n_expert_used, n_tokens);
5180+ ggml_tensor * weights = ggml_get_rows (ctx, ggml_reshape_3d (ctx, probs, 1 , n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
5181+ ggml_set_name(weights, "weights");
5182+
5183+ if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
5184+ weights = ggml_reshape_2d (ctx, weights, n_expert_used, n_tokens);
5185+ weights = ggml_soft_max (ctx, weights); // [n_expert_used, n_tokens]
5186+ weights = ggml_reshape_3d (ctx, weights, 1 , n_expert_used, n_tokens);
51625187 }
51635188
51645189 if (with_norm) {
5165- out = ggml_reshape_2d (ctx, out, n_expert_used, n_tokens);
5166- ggml_tensor * weights_sum = ggml_sum_rows (ctx, out); // [1, n_tokens]
5190+ weights = ggml_reshape_2d (ctx, weights, n_expert_used, n_tokens);
5191+ ggml_tensor * weights_sum = ggml_sum_rows (ctx, weights); // [1, n_tokens]
5192+ ggml_set_name(weights_sum, "weights_sum");
51675193
51685194 weights_sum = ggml_clamp (ctx, weights_sum, 6.103515625e-5 , INFINITY);
5169- out = ggml_div (ctx, out , weights_sum); // [n_expert_used, n_tokens]
5170- out = ggml_reshape_3d (ctx, out , 1 , n_expert_used, n_tokens);
5195+ weights = ggml_div (ctx, weights , weights_sum); // [n_expert_used, n_tokens]
5196+ weights = ggml_reshape_3d (ctx, weights , 1 , n_expert_used, n_tokens);
51715197 }
51725198
5173- ggml_set_name (out, " out" );
5174- return out;
5199+ if (scale_w) {
5200+ weights = ggml_scale (ctx, weights, scale_w);
5201+ }
5202+
5203+ ggml_set_name(weights, "weights");
5204+ return weights;
51755205 }
51765206};
51775207
@@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
79728002 }
79738003 }
79748004
7975- for (bool with_norm : {false , true }) {
7976- test_cases.emplace_back (new test_topk_moe ({8 , 22 , 1 , 1 }, 4 , with_norm));
7977- test_cases.emplace_back (new test_topk_moe ({31 , 22 , 1 , 1 }, 8 , with_norm));
7978- test_cases.emplace_back (new test_topk_moe ({32 , 22 , 1 , 1 }, 8 , with_norm));
7979- test_cases.emplace_back (new test_topk_moe ({40 , 22 , 1 , 1 }, 8 , with_norm));
7980- test_cases.emplace_back (new test_topk_moe ({71 , 22 , 1 , 1 }, 8 , with_norm));
7981- test_cases.emplace_back (new test_topk_moe ({128 , 1 , 1 , 1 }, 128 , with_norm));
7982- test_cases.emplace_back (new test_topk_moe ({129 , 1 , 1 , 1 }, 128 , with_norm));
8005+ for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
8006+ for (bool with_norm : {false , true }) {
8007+ for (bool bias_probs : {false , true }) {
8008+ for (float scale_w : {0.0f, 2.0f}) {
8009+ test_cases.emplace_back (new test_topk_moe ({8 , 22 , 1 , 1 }, 4 , with_norm, bias_probs, gate, scale_w));
8010+ test_cases.emplace_back (new test_topk_moe ({31 , 22 , 1 , 1 }, 8 , with_norm, bias_probs, gate, scale_w));
8011+ test_cases.emplace_back (new test_topk_moe ({32 , 22 , 1 , 1 }, 8 , with_norm, bias_probs, gate, scale_w));
8012+ test_cases.emplace_back (new test_topk_moe ({40 , 22 , 1 , 1 }, 8 , with_norm, bias_probs, gate, scale_w));
8013+ test_cases.emplace_back (new test_topk_moe ({71 , 22 , 1 , 1 }, 8 , with_norm, bias_probs, gate, scale_w));
8014+ test_cases.emplace_back (new test_topk_moe ({128 , 1 , 1 , 1 }, 128 , with_norm, bias_probs, gate, scale_w));
8015+ test_cases.emplace_back (new test_topk_moe ({129 , 1 , 1 , 1 }, 128 , with_norm, bias_probs, gate, scale_w));
8016+ }
8017+ }
8018+ }
79838019 }
79848020
7985- test_cases.emplace_back (new test_topk_moe ({ 8 , 22 , 1 , 1 }, 4 , /* with_norm*/ false , /* delayed_softmax*/ true ));
7986- test_cases.emplace_back (new test_topk_moe ({ 32 , 22 , 1 , 1 }, 8 , /* with_norm*/ false , /* delayed_softmax*/ true ));
7987-
79888021#if 0
79898022 // these tests are disabled to save execution time, but they can be handy for debugging
79908023 test_cases.emplace_back(new test_llama(2, true));
0 commit comments