10 changes: 10 additions & 0 deletions ggml/include/ggml.h
@@ -551,6 +551,7 @@ extern "C" {
        GGML_OP_GATED_LINEAR_ATTN,
        GGML_OP_RWKV_WKV7,
        GGML_OP_SOLVE_TRI,
        GGML_OP_DELTA_NET,

        GGML_OP_UNARY,

@@ -2460,6 +2461,15 @@ extern "C" {
            bool                  lower,
            bool                  uni);

    GGML_API struct ggml_tensor * ggml_delta_net(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,      // [S_k, n_tokens, H_k, n_seqs] - Query (pre-permuted)
            struct ggml_tensor  * k,      // [S_k, n_tokens, H_k, n_seqs] - Key (pre-permuted)
            struct ggml_tensor  * v,      // [S_v, n_tokens, H_v, n_seqs] - Value (pre-permuted)
            struct ggml_tensor  * g,      // [1, n_tokens, H_k, n_seqs] - Gate logits (pre-permuted)
            struct ggml_tensor  * beta,   // [1, n_tokens, H_k, n_seqs] - Beta logits (pre-permuted)
            struct ggml_tensor  * state); // [S_v, S_v*H_v, 1, n_seqs] - Recurrent state

    // custom operators

    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
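A minimal wiring sketch for orientation (not part of this PR): the helper name and the S/T/H/B parameters are mine, the shapes follow the header comments above, and it assumes S_k == S_v and H_k == H_v, which is also what the CPU kernel below assumes.

// hypothetical usage sketch; shapes follow the ggml_delta_net header comments
static struct ggml_tensor * build_delta_net_step(
        struct ggml_context * ctx,
        int64_t S,   // head dimension
        int64_t T,   // tokens
        int64_t H,   // heads
        int64_t B) { // sequences
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, T,   H, B);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, T,   H, B);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S*H, 1, B);

    // dst packs the attention output followed by the updated recurrent state,
    // mirroring the rwkv_wkv and gated_linear_attn operators
    return ggml_delta_net(ctx, q, k, v, g, beta, state);
}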
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2014,6 +2014,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_rwkv_wkv7(params, tensor);
            } break;
        case GGML_OP_DELTA_NET:
            {
                ggml_compute_forward_delta_net(params, tensor);
            } break;
        case GGML_OP_SOLVE_TRI:
            {
                ggml_compute_forward_solve_tri(params, tensor);
@@ -2339,6 +2343,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
        case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;
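Like the other recurrent sequence ops (rwkv_wkv6/7, gated_linear_attn), delta_net reports n_tasks = n_threads; the actual division of work happens inside the kernel, which partitions the flattened head range across threads (see ops.cpp below).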
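Before the kernel itself, it may help to spell out the recurrence it computes. Reconstructed from the loops below (notation mine), per head and per token t, with d = head_dim and S_0 the incoming state:

$$
\begin{aligned}
\hat q_t &= \frac{q_t}{\sqrt{d}\,\sqrt{\lVert q_t \rVert^2 + \varepsilon}}, \qquad
\hat k_t = \frac{k_t}{\sqrt{\lVert k_t \rVert^2 + \varepsilon}}, \\
\beta_t &= \sigma(b_t), \qquad
\gamma_t = \exp(\min(g_t, 50)), \\
v_t^{\mathrm{new}} &= \beta_t\, v_t - \beta_t\,\gamma_t\, S_{t-1}\hat k_t, \\
o_t &= \gamma_t\, S_{t-1}\hat q_t + (\hat k_t \cdot \hat q_t)\, v_t^{\mathrm{new}}, \\
S_t &= \operatorname{clamp}\!\left(\gamma_t\, S_{t-1} + v_t^{\mathrm{new}}\,\hat k_t^{\top},\ \pm 10^6\right).
\end{aligned}
$$

That is, a gated delta rule: the state's prediction for the current value is subtracted before the outer-product write-back, and the same exponential gate decays both the state and its contribution to the output.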
133 changes: 133 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -10091,6 +10091,139 @@ void ggml_compute_forward_rwkv_wkv7(
    }
}

// ggml_compute_forward_delta_net

static void ggml_compute_forward_delta_net_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // q
    const ggml_tensor * src1 = dst->src[1]; // k
    const ggml_tensor * src2 = dst->src[2]; // v
    const ggml_tensor * src3 = dst->src[3]; // gate logits
    const ggml_tensor * src4 = dst->src[4]; // beta logits
    const ggml_tensor * src5 = dst->src[5]; // incoming recurrent state

    // all dims are taken from q, i.e. the kernel assumes S_k == S_v and H_k == H_v
    const int64_t head_dim = src0->ne[0];
    const int64_t n_tokens = src0->ne[1];
    const int64_t n_heads  = src0->ne[2];
    const int64_t n_seqs   = src0->ne[3];

    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;

    const float * q_data    = (const float *) src0->data;
    const float * k_data    = (const float *) src1->data;
    const float * v_data    = (const float *) src2->data;
    const float * g_data    = (const float *) src3->data;
    const float * beta_data = (const float *) src4->data;
    const float * state_in  = (const float *) src5->data;
    float * out_data  = (float *) dst->data;
    float * state_out = out_data + output_size; // dst packs [attention output | updated state]

    const int ith = params->ith;
    const int nth = params->nth;

    // each (seq, head) pair carries an independent recurrence, so the flattened
    // head range is split evenly across the threads
    const int64_t total_heads      = n_heads * n_seqs;
    const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
    const int64_t h_start = ith * heads_per_thread;
    const int64_t h_end   = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;

    const float eps   = 1e-12f;                         // guards the L2 normalizations below
    const float scale = 1.0f / sqrtf((float) head_dim); // standard 1/sqrt(d) attention scaling

    // scratch for the per-token updated value vector
    float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
    if (!v_new_buf) {
        return;
    }

    for (int64_t h_idx = h_start; h_idx < h_end; h_idx++) {
        const int64_t batch_idx = h_idx / n_heads;
        const int64_t head_idx  = h_idx % n_heads;

        // flat offsets, assuming contiguous [S, T, H, B] inputs and a [S, H, T, B] output
        const int64_t qkv_head_offset   = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
        const int64_t qkv_token_stride  = head_dim;
        const int64_t g_head_offset     = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
        const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
        const int64_t out_head_offset   = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
        const int64_t out_token_stride  = head_dim * n_heads;

        // start from the incoming state; the recurrence below updates state_out in place
        for (int64_t i = 0; i < head_dim * head_dim; i++) {
            state_out[state_head_offset + i] = state_in[state_head_offset + i];
        }

        float * state = state_out + state_head_offset;

        for (int64_t t = 0; t < n_tokens; t++) {
            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;

            float g_val    = g_data[g_head_offset + t];
            float beta_raw = beta_data[g_head_offset + t];

            // L2-normalize q and k (eps keeps the rsqrt finite for zero vectors)
            float q_norm_sq = 0.0f, k_norm_sq = 0.0f;
            for (int64_t i = 0; i < head_dim; i++) {
                q_norm_sq += q_t[i] * q_t[i];
                k_norm_sq += k_t[i] * k_t[i];
            }
            float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
            float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);

            float beta_val = 1.0f / (1.0f + expf(-beta_raw)); // beta = sigmoid(beta_raw)
            float decay    = expf(fminf(g_val, 50.0f));       // decay = exp(g), clamped against overflow

            // scalar attention score between the normalized key and the scaled query
            float attn_score = 0.0f;
            for (int64_t i = 0; i < head_dim; i++) {
                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
            }

            float * out_t = out_data + out_head_offset + t * out_token_stride;

            for (int64_t row = 0; row < head_dim; row++) {
                float v_prime = 0.0f; // delta-rule prediction: beta * decay * (S k_hat)
                float out_val = 0.0f; // old-state contribution: decay * (S q_hat)

                for (int64_t col = 0; col < head_dim; col++) {
                    float k_col = k_t[col] * k_norm_inv;
                    float q_col = q_t[col] * q_norm_inv * scale;
                    float s = state[row + col * head_dim];

                    v_prime += s * k_col * beta_val * decay;
                    out_val += s * q_col * decay;
                }

                // delta rule: subtract the state's prediction from the beta-weighted value
                float v_new = v_t[row] * beta_val - v_prime;
                v_new_buf[row] = v_new;
                out_t[row] = out_val + v_new * attn_score;
            }

            // state update: S <- decay * S + v_new * k_hat^T, clamped elementwise for stability
            for (int64_t col = 0; col < head_dim; col++) {
                float k_col = k_t[col] * k_norm_inv;
                for (int64_t row = 0; row < head_dim; row++) {
                    float s = state[row + col * head_dim];
                    s = decay * s + v_new_buf[row] * k_col;
                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                }
            }
        }
    }

    free(v_new_buf);
}

void ggml_compute_forward_delta_net(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            ggml_compute_forward_delta_net_f32(params, dst);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}

// ggml_compute_forward_map_custom1

void ggml_compute_forward_map_custom1(
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
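The PR adds no tests; a throwaway smoke test along the following lines would at least exercise the CPU path. This is a sketch under assumptions: it uses the current ggml public API (ggml_graph_compute_with_ctx lives in ggml-cpu.h in recent trees) and only checks that the results are finite, not that they are numerically correct.

// delta_net smoke test sketch (not part of this PR)
#include "ggml.h"
#include "ggml-cpu.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int64_t S = 8, T = 4, H = 2, B = 1;

    struct ggml_init_params ip = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, T,   H, B);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, T,   H, B);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, T,   H, B);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S*H, 1, B);

    // fill all inputs with small random values
    struct ggml_tensor * ins[] = { q, k, v, g, beta, state };
    for (int i = 0; i < 6; i++) {
        float * d = (float *) ins[i]->data;
        for (int64_t j = 0; j < ggml_nelements(ins[i]); j++) {
            d[j] = (float) rand() / (float) RAND_MAX - 0.5f;
        }
    }

    struct ggml_tensor * out = ggml_delta_net(ctx, q, k, v, g, beta, state);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    // dst holds S*T*H*B output floats followed by the updated state
    const float * o = (const float *) out->data;
    for (int64_t j = 0; j < ggml_nelements(out); j++) {
        if (!isfinite(o[j])) {
            fprintf(stderr, "non-finite value at index %lld\n", (long long) j);
            return 1;
        }
    }

    printf("delta_net smoke test: ok\n");
    ggml_free(ctx);
    return 0;
}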