39 changes: 39 additions & 0 deletions ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
GGML_OP_SOFT_MAX,
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_COMP,
GGML_OP_ROPE_BACK,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
@@ -1858,6 +1859,44 @@ extern "C" {
GGML_API void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);


enum ggml_rope_ordering {
GGML_ROPE_ORDERING_NORMAL,
GGML_ROPE_ORDERING_NEOX,
};

// RoPE composable API
GGML_API struct ggml_tensor * ggml_rope_comp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b, // pos must be F32
int32_t n_dims,
float freq_base,
enum ggml_rope_ordering ordering);

GGML_API struct ggml_tensor * ggml_rope_comp_set_freq_factors(
struct ggml_context * ctx,
struct ggml_tensor * node,
struct ggml_tensor * freq_factors);

// set YaRN parameters
GGML_API struct ggml_tensor * ggml_rope_comp_set_yarn(
struct ggml_context * ctx,
struct ggml_tensor * node,
int n_ctx_orig,
float freq_base,
float freq_scale, // == 1.0f / scale_factor
float ramp_factor, // usually 1.0f
float attn_factor,
float beta_fast,
float beta_slow);

// set M-RoPE mode
GGML_API struct ggml_tensor * ggml_rope_comp_set_multi(
struct ggml_context * ctx,
struct ggml_tensor * node,
int sections[GGML_MROPE_SECTIONS]);

// rotary position embedding backward, i.e. compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_ext_back(
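Taken together, the declarations above imply a builder-style flow: ggml_rope_comp creates the node with the core rotation parameters, and the ggml_rope_comp_set_* calls opt into frequency factors, YaRN, or M-RoPE. A minimal usage sketch against these signatures; cur, pos, freq_factors, and all numeric values are illustrative assumptions, not taken from this diff:

    // hypothetical graph-build snippet; ctx, cur (activations), pos (F32 positions)
    // and freq_factors are assumed to exist already
    struct ggml_tensor * x = ggml_rope_comp(ctx, cur, pos, /*n_dims=*/128,
                                            /*freq_base=*/10000.0f, GGML_ROPE_ORDERING_NEOX);
    x = ggml_rope_comp_set_freq_factors(ctx, x, freq_factors); // optional per-channel factors
    x = ggml_rope_comp_set_yarn(ctx, x, /*n_ctx_orig=*/4096, /*freq_base=*/10000.0f,
                                /*freq_scale=*/1.0f/4.0f, /*ramp_factor=*/1.0f,
                                /*attn_factor=*/1.0f, /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);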
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1863,6 +1863,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rope(params, tensor);
} break;
case GGML_OP_ROPE_COMP:
{
ggml_compute_forward_rope_comp(params, tensor);
} break;
case GGML_OP_ROPE_BACK:
{
ggml_compute_forward_rope_back(params, tensor);
@@ -2294,6 +2298,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
case GGML_OP_ADD_REL_POS:
{
@@ -2812,6 +2817,7 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_ROPE_BACK:
{
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
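For scale: this hunk reserves one row of ne[0] floats per task for the per-position cos/sin cache that the new kernel fills (see ops.cpp below), e.g. ne[0] = 128 with n_tasks = 8 gives 4 * 128 * 8 = 4096 bytes. ggml_graph_plan separately pads the total work buffer so each thread's slice can start on its own cache line, which is what the (ne0 + CACHE_LINE_SIZE_F32)*ith offset in the kernel relies on.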
148 changes: 148 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -5817,6 +5817,154 @@ void ggml_compute_forward_rope(
}
}

// ggml_compute_forward_rope_comp

template<typename T> // float or ggml_fp16_t
static void ggml_compute_forward_rope_comp_flt(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

int32_t n_dims, idx_pair, idx_scale, idx_offset;
float theta_scale, yarn_high, yarn_low, freq_scale, ramp_factor, attn_factor;
int32_t sections[4];

memcpy(&n_dims, (int32_t *)dst->op_params + 0, sizeof(int32_t));
memcpy(&idx_pair, (int32_t *)dst->op_params + 1, sizeof(int32_t));
memcpy(&idx_scale, (int32_t *)dst->op_params + 2, sizeof(int32_t));
memcpy(&idx_offset, (int32_t *)dst->op_params + 3, sizeof(int32_t));
memcpy(&theta_scale, (int32_t *)dst->op_params + 4, sizeof(float));
memcpy(&yarn_high, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&yarn_low, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&ramp_factor, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&sections[0], (int32_t *)dst->op_params + 10, sizeof(int32_t));
memcpy(&sections[1], (int32_t *)dst->op_params + 11, sizeof(int32_t));
memcpy(&sections[2], (int32_t *)dst->op_params + 12, sizeof(int32_t));
memcpy(&sections[3], (int32_t *)dst->op_params + 13, sizeof(int32_t));

GGML_TENSOR_UNARY_OP_LOCALS

GGML_ASSERT(nb0 == nb00);
GGML_ASSERT(nb0 == sizeof(T));

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(dst);

GGML_ASSERT(n_dims <= ne0);
GGML_ASSERT(n_dims % 2 == 0);

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

// row index used to determine which thread to use
int ir = 0;

// TODO M-RoPE

const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
freq_factors = (const float *) src2->data;
}

// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;

const float * pos = (const float *) src1->data;

auto init_cache = [&](float * cache, float p) -> void {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float freq_factor = freq_factors ? freq_factors[i0/2] : 1.0f;

float theta = p * powf(theta_scale, i0/2) / freq_factor;
const float theta_extrap = theta;
const float theta_interp = freq_scale * theta;

if (ramp_factor != 0.0f) {
const float ramp_mix = rope_yarn_ramp(yarn_low, yarn_high, i0) * ramp_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
} else {
theta = theta_interp;
}

cache[i0 + 0] = cosf(theta) * attn_factor;
cache[i0 + 1] = sinf(theta) * attn_factor * sin_sign;
}
};

for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
{
const float p = pos[i2];
init_cache(cache, p);
}
// TODO M-RoPE

for (int64_t i1 = idx_offset; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;

T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);

rotate_pairs<T>(n_dims, idx_pair, cache, src, dst_data, idx_scale);
// TODO M-RoPE

// fill the remaining channels with data from the src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
} // attn-heads
}
}
}
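// note: with theta_extrap = p * theta_scale^(i0/2) / freq_factor and
// theta_interp = freq_scale * theta_extrap, init_cache above computes the YaRN blend
//
//   theta    = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix
//   ramp_mix = rope_yarn_ramp(yarn_low, yarn_high, i0) * ramp_factor
//
// so ramp_factor == 0.0f reduces to plain frequency interpolation, and each cached
// (cos, sin) pair is pre-scaled by attn_factor (sin additionally by sin_sign)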

void ggml_compute_forward_rope_comp(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_comp_flt<ggml_fp16_t>(params, dst, true);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_comp_flt<float>(params, dst, true);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_rope_back

void ggml_compute_forward_rope_back(
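The memcpy sequence in ggml_compute_forward_rope_comp_flt pins down the op_params layout: 4-byte slots 0-3 hold the int32 values n_dims, idx_pair, idx_scale, idx_offset; slots 4-9 hold the floats theta_scale, yarn_high, yarn_low, freq_scale, attn_factor, ramp_factor; slots 10-13 hold the int32 M-RoPE sections. A sketch of the matching packing on the graph-build side, assuming a ggml_set_op_params-style helper (the build function itself is not shown in this excerpt):

    // hypothetical packing mirroring the kernel's unpacking (14 x 4-byte slots)
    int32_t params[14] = {0};
    memcpy(params +  0, &n_dims,      sizeof(int32_t));
    memcpy(params +  1, &idx_pair,    sizeof(int32_t));
    memcpy(params +  2, &idx_scale,   sizeof(int32_t));
    memcpy(params +  3, &idx_offset,  sizeof(int32_t));
    memcpy(params +  4, &theta_scale, sizeof(float));
    memcpy(params +  5, &yarn_high,   sizeof(float));
    memcpy(params +  6, &yarn_low,    sizeof(float));
    memcpy(params +  7, &freq_scale,  sizeof(float));
    memcpy(params +  8, &attn_factor, sizeof(float));
    memcpy(params +  9, &ramp_factor, sizeof(float));
    memcpy(params + 10, sections,     4*sizeof(int32_t));
    ggml_set_op_params(result, params, sizeof(params));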
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -61,6 +61,7 @@ void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * para
void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_comp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -261,6 +261,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
17 changes: 17 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1475,6 +1475,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ROPE_COMP);

char base[256];
char name[256];

snprintf(base, 256, "kernel_rope_comp_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_IM2COL);

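Note that the pipeline name above is specialized only on the source tensor type, so the Metal library is expected to provide kernel_rope_comp_f32 and kernel_rope_comp_f16 variants, matching the two types accepted by the CPU path.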
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -136,6 +136,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
@@ -1029,6 +1029,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
case GGML_OP_ROPE_COMP:
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
34 changes: 34 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -260,6 +260,40 @@ typedef struct {
bool src2;
} ggml_metal_kargs_rope;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t n_dims;
int32_t idx_pair;
int32_t idx_scale;
int32_t idx_offset;
float theta_scale;
float yarn_high;
float yarn_low;
float freq_scale;
float ramp_factor;
float attn_factor;
int32_t sect_0;
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope_comp;

typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
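The scalar fields mirror the CPU-side op_params layout, plus the usual shape/stride fields and an src2 flag for the optional frequency factors. A hedged sketch of the host-side fill (the actual encoder code is not part of this excerpt; field names come from the struct above):

    // hypothetical population of the kargs struct from the op;
    // scalar offsets follow the unpacking in ggml-cpu/ops.cpp
    ggml_metal_kargs_rope_comp args;
    memset(&args, 0, sizeof(args));
    args.ne00 = (int32_t) op->src[0]->ne[0];
    args.nb00 = op->src[0]->nb[0];
    // ... ne01..ne03/nb01..nb03 and ne0..ne3/nb0..nb3 follow the same pattern
    memcpy(&args.n_dims,      (const int32_t *) op->op_params +  0, sizeof(int32_t));
    memcpy(&args.idx_pair,    (const int32_t *) op->op_params +  1, sizeof(int32_t));
    memcpy(&args.idx_scale,   (const int32_t *) op->op_params +  2, sizeof(int32_t));
    memcpy(&args.idx_offset,  (const int32_t *) op->op_params +  3, sizeof(int32_t));
    memcpy(&args.theta_scale, (const int32_t *) op->op_params +  4, sizeof(float));
    memcpy(&args.yarn_high,   (const int32_t *) op->op_params +  5, sizeof(float));
    memcpy(&args.yarn_low,    (const int32_t *) op->op_params +  6, sizeof(float));
    memcpy(&args.freq_scale,  (const int32_t *) op->op_params +  7, sizeof(float));
    memcpy(&args.attn_factor, (const int32_t *) op->op_params +  8, sizeof(float));
    memcpy(&args.ramp_factor, (const int32_t *) op->op_params +  9, sizeof(float));
    memcpy(&args.sect_0,      (const int32_t *) op->op_params + 10, 4*sizeof(int32_t)); // sect_0..sect_3 are contiguous
    args.src2 = op->src[2] != NULL;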