37 changes: 37 additions & 0 deletions ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
GGML_OP_SOFT_MAX,
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_COMP,
GGML_OP_ROPE_BACK,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
@@ -1858,6 +1859,42 @@ extern "C" {
GGML_API void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);


enum ggml_rope_ordering {
GGML_ROPE_ORDERING_NORMAL,
GGML_ROPE_ORDERING_NEOX,
};

// demo new RoPE API (NOT yet to be merged)
// RoPE composable API
GGML_API struct ggml_tensor * ggml_rope_comp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int32_t n_dims,
float freq_base,
enum ggml_rope_ordering ordering);

// TODO: ggml_rope_comp_set_rope_factors

// set YaRN parameters
GGML_API struct ggml_tensor * ggml_rope_comp_set_yarn(
struct ggml_context * ctx,
struct ggml_tensor * node,
int n_ctx_orig,
float freq_base,
float freq_scale, // == 1.0f / scale_factor
float ramp_factor, // usually 1.0f
float attn_factor,
float beta_fast,
float beta_slow);

// set M-RoPE mode
GGML_API struct ggml_tensor * ggml_rope_comp_set_multi(
struct ggml_context * ctx,
struct ggml_tensor * node,
int sections[GGML_MROPE_SECTIONS]);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_ext_back(
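For context (not part of the diff): the intended usage appears to be to build the basic rotation with ggml_rope_comp() and then attach optional behaviour with the setter calls. A minimal sketch, assuming an existing ggml_context * ctx, an activation tensor cur of shape [head_dim, n_head, n_tokens] and an I32 position vector pos; all numeric values are illustrative:

    // illustrative only — tensor names and parameter values are placeholders, not part of the PR
    struct ggml_tensor * q = ggml_rope_comp(ctx, cur, pos,
            /*n_dims*/ 128, /*freq_base*/ 10000.0f, GGML_ROPE_ORDERING_NEOX);

    // optional: YaRN context extension (freq_scale == 1.0f / scale_factor)
    q = ggml_rope_comp_set_yarn(ctx, q,
            /*n_ctx_orig*/ 4096, /*freq_base*/ 10000.0f, /*freq_scale*/ 0.25f,
            /*ramp_factor*/ 1.0f, /*attn_factor*/ 1.0f, /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

With freq_base = 10000 and n_dims = 128 the base call sets theta_scale = 10000^(-2/128) ≈ 0.866, so dimension pair i is rotated by pos * 0.866^i — the standard RoPE frequency schedule.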
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1863,6 +1863,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rope(params, tensor);
} break;
case GGML_OP_ROPE_COMP:
{
ggml_compute_forward_rope_comp(params, tensor);
} break;
case GGML_OP_ROPE_BACK:
{
ggml_compute_forward_rope_back(params, tensor);
@@ -2294,6 +2298,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
case GGML_OP_ADD_REL_POS:
{
@@ -2812,6 +2817,7 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
{
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
150 changes: 150 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -5817,6 +5817,156 @@ void ggml_compute_forward_rope(
}
}

// ggml_compute_forward_rope_comp

template<typename T> //float or ggml_fp16_t
static void ggml_compute_forward_rope_comp_flt(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_I32);

int32_t n_dims, idx_pair, idx_scale, idx_offset;
float theta_scale, yarn_high, yarn_low, freq_scale, ramp_factor, attn_factor;
int32_t sections[4];

memcpy(&n_dims, (int32_t *)dst->op_params + 0, sizeof(int32_t));
memcpy(&idx_pair, (int32_t *)dst->op_params + 1, sizeof(int32_t));
memcpy(&idx_scale, (int32_t *)dst->op_params + 2, sizeof(int32_t));
memcpy(&idx_offset, (int32_t *)dst->op_params + 3, sizeof(int32_t));
memcpy(&theta_scale, (int32_t *)dst->op_params + 4, sizeof(float));
memcpy(&yarn_high, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&yarn_low, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&ramp_factor, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&sections[0], (int32_t *)dst->op_params + 10, sizeof(int32_t));
memcpy(&sections[1], (int32_t *)dst->op_params + 11, sizeof(int32_t));
memcpy(&sections[2], (int32_t *)dst->op_params + 12, sizeof(int32_t));
memcpy(&sections[3], (int32_t *)dst->op_params + 13, sizeof(int32_t));

GGML_TENSOR_UNARY_OP_LOCALS

GGML_ASSERT(nb0 == nb00);
GGML_ASSERT(nb0 == sizeof(T));

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(dst);

GGML_ASSERT(n_dims <= ne0);
GGML_ASSERT(n_dims % 2 == 0);

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

// row index used to determine which thread to use
int ir = 0;

// TODO M-RoPE

const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
freq_factors = (const float *) src2->data;
}

// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;

const int32_t * pos = (const int32_t *) src1->data;

auto init_cache = [&](float * cache, float theta_base) -> void {
float theta = theta_base;
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
// yarn: get n-d rotational scaling corrected for extrapolation;
// keep the result in a local so the running theta is not polluted by ff/freq_scale
float theta_cur;
{
const float theta_extrap = theta / ff;
const float theta_interp = freq_scale * theta_extrap;
theta_cur = theta_interp;
if (ramp_factor != 0.0f) {
const float ramp_mix = rope_yarn_ramp(yarn_high, yarn_low, i0) * ramp_factor;
theta_cur = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
}
}
cache[i0 + 0] = cosf(theta_cur) * attn_factor;
cache[i0 + 1] = sinf(theta_cur) * attn_factor * sin_sign;

theta *= theta_scale;
}
};

for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
{
const int64_t p = pos[i2];
init_cache(cache, p);
}
// TODO M-RoPE

for (int64_t i1 = idx_offset; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;

T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);

rotate_pairs<T>(n_dims, idx_pair, cache, src, dst_data, idx_scale);
// TODO M-RoPE

// fill the remaining channels with data from the src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
} //attn-heads
}
}
}

void ggml_compute_forward_rope_comp(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_comp_flt<ggml_fp16_t>(params, dst, /*forward=*/true);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_comp_flt<float>(params, dst, /*forward=*/true);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_rope_back

void ggml_compute_forward_rope_back(
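rotate_pairs<T>() is called above but defined elsewhere in ops.cpp, so its body is not part of this hunk. A rough sketch of what the call site and the cache layout imply, assuming idx_pair is the offset between the two elements of a rotated pair and idx_scale maps the data index onto the cos/sin cache index (idx_pair = 1, idx_scale = 1 for normal ordering; idx_pair = n_dims/2, idx_scale = 2 for NEOX); shown for the float case only — the actual helper may differ:

    // hypothetical sketch for T = float; the ggml_fp16_t instantiation would additionally
    // go through the FP16<->FP32 conversion helpers
    static void rotate_pairs_f32(int64_t n_dims, int64_t idx_pair, const float * cache,
                                 const float * src, float * dst, int64_t idx_scale) {
        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
            const int64_t ic = i0 / idx_scale;      // data index of the first element of the pair
            const float cos_theta = cache[i0 + 0];  // cache stores {cos, sin} per rotated pair
            const float sin_theta = cache[i0 + 1];

            const float x0 = src[ic + 0];
            const float x1 = src[ic + idx_pair];

            dst[ic + 0]        = x0*cos_theta - x1*sin_theta;
            dst[ic + idx_pair] = x0*sin_theta + x1*cos_theta;
        }
    }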
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -61,6 +61,7 @@ void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * para
void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_comp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
88 changes: 86 additions & 2 deletions ggml/src/ggml.c
@@ -991,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SOFT_MAX",
"SOFT_MAX_BACK",
"ROPE",
"ROPE_COMP",
"ROPE_BACK",
"CLAMP",
"CONV_TRANSPOSE_1D",
@@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1154,7 +1155,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};

static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4265,6 +4266,88 @@ void ggml_rope_yarn_corr_dims(
dims[1] = MIN(n_dims - 1, end);
}

// ggml_rope_comp

GGML_API struct ggml_tensor * ggml_rope_comp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int32_t n_dims,
float freq_base,
enum ggml_rope_ordering ordering) {
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);

GGML_ASSERT(b->ne[0] >= a->ne[2]); // also allow M-RoPE
GGML_ASSERT(b->ne[0] % a->ne[2] == 0);

int32_t idx_pair = 1;
int32_t idx_scale = 1;
if (ordering == GGML_ROPE_ORDERING_NEOX) {
idx_pair = n_dims / 2;
idx_scale = 2;
}

// note: theta = theta_base * theta_scale^i
const float theta_scale = powf(freq_base, -2.0f / (float)n_dims);

int32_t i_zero = 0;
float f_zero = 0.0f;
float f_one = 1.0f;

struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[15];
memset(params, 0, sizeof(params));
memcpy(params + 0, &n_dims, sizeof(int32_t)); // n_dims
memcpy(params + 1, &idx_pair, sizeof(int32_t)); // idx_pair
memcpy(params + 2, &idx_scale, sizeof(int32_t)); // idx_scale
memcpy(params + 3, &i_zero, sizeof(int32_t)); // idx_offset for 2D-RoPE
memcpy(params + 4, &theta_scale, sizeof(float)); // theta_scale
memcpy(params + 5, &f_zero, sizeof(float)); // yarn_high
memcpy(params + 6, &f_zero, sizeof(float)); // yarn_low
memcpy(params + 7, &f_one, sizeof(float)); // freq_scale (1.0f == no position scaling)
memcpy(params + 8, &f_one, sizeof(float)); // attn_factor
memcpy(params + 9, &f_zero, sizeof(float)); // ramp_factor
memcpy(params + 10, &i_zero, sizeof(int32_t)); // sections[0]
memcpy(params + 11, &i_zero, sizeof(int32_t)); // sections[1]
memcpy(params + 12, &i_zero, sizeof(int32_t)); // sections[2]
memcpy(params + 13, &i_zero, sizeof(int32_t)); // sections[3]
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_ROPE_COMP;
result->src[0] = a;
result->src[1] = b;
result->src[2] = NULL;

return result;
}

struct ggml_tensor * ggml_rope_comp_set_yarn(
struct ggml_context * ctx,
struct ggml_tensor * node,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ramp_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
GGML_UNUSED(ctx);
GGML_ASSERT(node->op == GGML_OP_ROPE_COMP);

const int32_t n_dims = *((int32_t *) node->op_params + 0);

// corr dims derived from beta_fast/beta_slow; the CPU kernel passes these as the
// (low, high) arguments of rope_yarn_ramp()
float yarn_high = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
float yarn_low = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
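// illustrative numbers (not from the PR): with n_dims = 128, n_ctx_orig = 4096, freq_base = 10000,
// beta_fast = 32 and beta_slow = 1 the two corr dims come out at ~20.9 and ~45.0, so after the
// floor/ceil the ramp covers dimension pairs 20..46 — pairs below keep their original frequencies
// (extrapolation), pairs above are fully scaled by freq_scale (interpolation)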

memcpy((float *) node->op_params + 5, &yarn_high, sizeof(float));
memcpy((float *) node->op_params + 6, &yarn_low, sizeof(float));
memcpy((float *) node->op_params + 7, &freq_scale, sizeof(float));
memcpy((float *) node->op_params + 8, &attn_factor, sizeof(float));
memcpy((float *) node->op_params + 9, &ramp_factor, sizeof(float));
return node;
}

// ggml_rope_back

struct ggml_tensor * ggml_rope_ext_back(
@@ -6848,6 +6931,7 @@ void ggml_build_backward_expand(
case GGML_OP_GET_ROWS: // row indices not differentiable
case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
case GGML_OP_ROPE: // positions not differentiable
case GGML_OP_ROPE_COMP: // same as for ROPE
ignore_src[1] = true;
break;

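A possible way to sanity-check the new op (not part of this PR) is to run it next to the existing ggml_rope_ext() path on the CPU backend and compare the outputs. The program below is a rough sketch under the assumptions that ggml_graph_compute_with_ctx() from ggml-cpu.h is available and that ggml_rope_comp() without the YaRN setter should match plain NEOX RoPE (ext_factor = 0); the file name, tensor sizes and parameter values are illustrative.

    // rope_comp_check.c — illustrative sketch, not part of the PR
    #include "ggml.h"
    #include "ggml-cpu.h"

    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        const int head_dim = 128, n_head = 4, n_tokens = 32;

        struct ggml_init_params ip = { /*mem_size*/ 64*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);

        for (int64_t i = 0; i < ggml_nelements(x); i++) {
            ((float *) x->data)[i] = (float) rand() / (float) RAND_MAX - 0.5f;
        }
        for (int i = 0; i < n_tokens; i++) {
            ((int32_t *) pos->data)[i] = i;
        }

        // reference: existing API, NEOX ordering, YaRN disabled (ext_factor = 0)
        struct ggml_tensor * ref = ggml_rope_ext(ctx, x, pos, NULL, head_dim, GGML_ROPE_TYPE_NEOX,
                /*n_ctx_orig*/ 4096, /*freq_base*/ 10000.0f, /*freq_scale*/ 1.0f, /*ext_factor*/ 0.0f,
                /*attn_factor*/ 1.0f, /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

        // new composable op with the same configuration
        struct ggml_tensor * cmp = ggml_rope_comp(ctx, x, pos, head_dim, 10000.0f, GGML_ROPE_ORDERING_NEOX);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ref);
        ggml_build_forward_expand(gf, cmp);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 4);

        // compare element-wise
        float max_err = 0.0f;
        for (int64_t i = 0; i < ggml_nelements(ref); i++) {
            const float d = fabsf(((float *) ref->data)[i] - ((float *) cmp->data)[i]);
            max_err = d > max_err ? d : max_err;
        }
        printf("max abs diff: %g\n", max_err);

        ggml_free(ctx);
        return 0;
    }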