16 changes: 13 additions & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3076,16 +3076,23 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }

     if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];

-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
19 changes: 17 additions & 2 deletions ggml/src/ggml-cuda/topk-moe.cu
@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
[Review comment from the PR author on this line: "I also noticed that clamp is always nullptr. I didn't attempt to fix this."]

+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

-    const int n_expert = softmax->ne[0];
     // n_expert must be a power of 2
     if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
         return false;
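For context, the routing pattern this new guard walks is the one exercised by the bias_probs cases added to test-backend-ops.cpp further down: when a per-expert bias is applied before top-k selection (as in models that bias expert selection), argsort->src[0] is the biased tensor while get_rows reaches the raw probabilities through a reshape, so probs != selection_probs and the fusion is correctly rejected. A minimal sketch of that subgraph, assuming the ggml headers and a live ggml_context; the build_routing helper and tensor names are illustrative, not part of this PR:

    #include "ggml.h"

    // Sketch of a routing subgraph with biased expert selection.
    static ggml_tensor * build_routing(ggml_context * ctx,
                                       int n_expert, int n_expert_used, int n_tokens) {
        // Router outputs and the optional per-expert bias (assumed graph inputs).
        ggml_tensor * logits      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_expert, n_tokens);
        ggml_tensor * exp_probs_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_expert, n_tokens);

        ggml_tensor * probs = ggml_soft_max(ctx, logits);

        // Bias only the selection scores; the routing weights still come from probs.
        ggml_tensor * selection_probs = ggml_add(ctx, probs, exp_probs_b);

        // Here argsort->src[0] == selection_probs, while get_rows->src[0] is a
        // reshape of probs -- exactly the mismatch the guard above detects.
        ggml_tensor * selected = ggml_argsort_top_k(ctx, selection_probs, n_expert_used);
        return ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);
    }

Without the bias, selection_probs and probs are the same tensor, the pointers compare equal, and the fused path remains eligible.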
7 changes: 6 additions & 1 deletion ggml/src/ggml-cuda/topk-moe.cuh
@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool delayed_softmax = false,
                            ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
19 changes: 19 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12889,24 +12889,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc

     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;

     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }

+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     const float * op_params = (const float *)softmax->op_params;

     float scale = op_params[0];
91 changes: 62 additions & 29 deletions tests/test-backend-ops.cpp
@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
     }
 };

+enum MoeGatingFunc {
+    GATING_FUNC_SOFTMAX,
+    GATING_FUNC_SIGMOID,
+    GATING_FUNC_SOFTMAX_WEIGHT,
+};
+
 struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    const bool delayed_softmax;
+    const bool bias_probs;
+    const MoeGatingFunc gating_func;
+    const float scale_w;

     test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
                   int n_expert_used = 1,
                   bool with_norm = false,
-                  bool delayed_softmax = false) :
+                  bool bias_probs = false,
+                  MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
+                  float scale_w = 0.0f) :
         ne(ne),
         n_expert_used(n_expert_used),
         with_norm(with_norm),
-        delayed_softmax(delayed_softmax) {
+        bias_probs(bias_probs),
+        gating_func(gating_func),
+        scale_w(scale_w) {
         GGML_ASSERT(n_expert_used <= ne[0]);
-        GGML_ASSERT(!(with_norm && delayed_softmax));
     }

-    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
+    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }

     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
     const int n_tokens = ne[1];

     ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-    ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * probs =
+        (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
+        (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
+    ggml_set_name(probs, "probs");
+
+    ggml_tensor * selection_probs = probs;
+    if (bias_probs) {
+        ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+        ggml_set_name(exp_probs_b, "exp_probs_b");
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
+        ggml_set_name(selection_probs, "selection_probs");
+    }

-    ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_set_name(selected_experts, "selected_experts");

-    if (delayed_softmax) {
-        out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-        out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
-        out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+    ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    ggml_set_name(weights, "weights");
+
+    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
     }

     if (with_norm) {
-        out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+        ggml_set_name(weights_sum, "weights_sum");

         weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
-        out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
-        out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
     }

-    ggml_set_name(out, "out");
-    return out;
+    if (scale_w) {
+        weights = ggml_scale(ctx, weights, scale_w);
+    }
+
+    ggml_set_name(weights, "weights");
+    return weights;
 }
 };
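Note that the removed delayed_softmax mode survives as GATING_FUNC_SOFTMAX_WEIGHT: top-k selection runs on the raw logits and softmax is applied to the gathered weights afterwards. A hedged example of calling the updated constructor directly, equivalent to one of the old delayed_softmax cases removed in the next hunk (values illustrative):

    // Former ({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true):
    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8,
        /*with_norm*/ false, /*bias_probs*/ false, GATING_FUNC_SOFTMAX_WEIGHT, /*scale_w*/ 0.0f));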

@@ -7972,19 +8002,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

-    for (bool with_norm : {false, true}) {
-        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
-        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
-        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
+    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
+        for (bool with_norm : {false, true}) {
+            for (bool bias_probs : {false, true}) {
+                for (float scale_w : {0.0f, 2.0f}) {
+                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                }
+            }
+        }
     }

-    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
-    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));

 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));