21 changes: 21 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -384,6 +384,27 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); // logits
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); // labels

    char base[256];
    char name[256];

    const ggml_type tsrc1 = GGML_TYPE_F32;

    snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    // staging buffer for the cross-simdgroup reductions: one float per simdgroup lane
    res.smem = 32*sizeof(float);

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
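For reference, a scalar C++ sketch of the per-row math that the new kernel_cross_entropy_loss below parallelizes; the helper name cross_entropy_row and its parameters are illustrative, not part of this change:

#include <algorithm>
#include <cmath>

// per-row loss: -sum_i labels[i] * log_softmax(logits)[i] / nrows
static float cross_entropy_row(const float * logits, const float * labels,
                               int n_classes, int nrows) {
    float max_val = -INFINITY;
    for (int i = 0; i < n_classes; ++i) {
        max_val = std::max(max_val, logits[i]); // shift by the max for numerical stability
    }
    float sum = 0.0f;
    for (int i = 0; i < n_classes; ++i) {
        sum += std::exp(logits[i] - max_val);
    }
    const float log_sum = std::log(sum);
    float loss = 0.0f;
    for (int i = 0; i < n_classes; ++i) {
        loss += (logits[i] - max_val - log_sum) * labels[i]; // log_softmax * label
    }
    return -loss / nrows;
}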
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
11 changes: 11 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -908,6 +908,17 @@ typedef struct {
    int64_t np;
} ggml_metal_kargs_pool_2d;

typedef struct {
    int32_t ne00;
    int32_t nrows;
    int32_t k; // NOTE: not read by the kernel yet
} ggml_metal_kargs_cross_entropy_loss;

typedef struct {
    int32_t ne00;
    int32_t nrows;
} ggml_metal_kargs_cross_entropy_loss_back;

typedef struct {
    int64_t  ne00;
    uint64_t nb01;
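These kargs structs are passed to the GPU by value through ggml_metal_encoder_set_bytes, so their layout must match exactly what the kernels read. A hedged sanity-check sketch, placed after the struct definitions; the static_asserts are illustrative, not part of this change:

#include <cstdint>

// tightly packed int32 fields - a layout mismatch with what the .metal side
// expects would silently corrupt every kernel argument read
static_assert(sizeof(ggml_metal_kargs_cross_entropy_loss)      == 3*sizeof(int32_t), "kargs layout mismatch");
static_assert(sizeof(ggml_metal_kargs_cross_entropy_loss_back) == 2*sizeof(int32_t), "kargs layout mismatch");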
40 changes: 40 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1333,6 +1333,46 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
    return 1;
}

int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_tensor * src0 = op->src[0]; // logits
    const ggml_tensor * src1 = op->src[1]; // labels

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int32_t ne00  = src0->ne[0];
    const int32_t nrows = ggml_nrows(src1);

    ggml_metal_kargs_cross_entropy_loss args = {
        /*.ne00  =*/ ne00,
        /*.nrows =*/ nrows,
        /*.k     =*/ nrows, // NOTE: not read by the kernel yet
    };

    const int nth = 32; // one simdgroup per threadgroup

    auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);

    const size_t smem = pipeline.smem;

    // argument order matches the kernel signature: args, logits, labels, dst
    ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2); // asserted non-null above
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),   3);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

    // one threadgroup per row - the kernel indexes rows via tgpig
    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);

    return 1;
}

int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

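A hedged sketch of how a graph reaching ggml_metal_op_cross_entropy_loss might be built through the public ggml API; the shapes and the build_loss helper are illustrative:

#include "ggml.h"

// logits and labels: [n_classes, n_rows] F32 tensors of the same shape;
// labels are expected to be probabilities (e.g. one-hot or soft targets)
static struct ggml_tensor * build_loss(struct ggml_context * ctx,
                                       struct ggml_tensor  * logits,
                                       struct ggml_tensor  * labels) {
    return ggml_cross_entropy_loss(ctx, logits, labels);
}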
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
@@ -57,6 +57,7 @@ int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
140 changes: 140 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -2280,6 +2280,146 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;

template <typename T>
kernel void kernel_cross_entropy_loss(
        constant ggml_metal_kargs_cross_entropy_loss & args,
        device const float * logits,
        device const float * labels,
        device       float * dst,
        threadgroup  float * buf  [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint tptg [[threads_per_threadgroup]]) {
    const int row = tgpig;

    device const float * logits_row = logits + row*args.ne00;
    device const float * labels_row = labels + row*args.ne00;

    // max over the row, for numerical stability
    float lmax = -INFINITY;
    for (int i = tpitg; i < args.ne00; i += tptg) {
        lmax = MAX(lmax, logits_row[i]);
    }
    float max_val = simd_max(lmax);
    if (tptg > N_SIMDWIDTH) {
        if (sgitg == 0) buf[tiisg] = -INFINITY;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) buf[sgitg] = max_val;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // softmax denominator: sum of exponentials
    float lsum = 0.0f;
    for (int i = tpitg; i < args.ne00; i += tptg) {
        lsum += exp(logits_row[i] - max_val);
    }
    float sum = simd_sum(lsum);
    if (tptg > N_SIMDWIDTH) {
        if (sgitg == 0) buf[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) buf[sgitg] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // loss contribution: sum_i labels_i * log_softmax(logits)_i
    const float log_sum = log(sum);
    float lloss = 0.0f;
    for (int i = tpitg; i < args.ne00; i += tptg) {
        const float log_softmax_i = logits_row[i] - max_val - log_sum;
        lloss += log_softmax_i * labels_row[i];
    }
    float loss = simd_sum(lloss);
    if (tptg > N_SIMDWIDTH) {
        if (sgitg == 0) buf[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) buf[sgitg] = loss;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loss = buf[tiisg];
        loss = simd_sum(loss);
    }

    // per-row loss; assumes dst holds at least nrows elements
    if (tpitg == 0) {
        dst[row] = -loss / args.nrows;
    }
}


template <typename T>
kernel void kernel_cross_entropy_loss_back(
        constant ggml_metal_kargs_cross_entropy_loss_back & args,
        device const float * grad,
        device const float * logits, // src0
        device const float * labels, // src1
        device       float * dst,    // output
        threadgroup  float * buf  [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint tptg [[threads_per_threadgroup]]) {
    const int row = tgpig; // one threadgroup per row

    device const float * logits_row = logits + row*args.ne00;
    device const float * labels_row = labels + row*args.ne00;
    device       float * dst_row    = dst    + row*args.ne00;

    // max over the row, for numerical stability
    float lmax = -INFINITY;
    for (int i = tpitg; i < args.ne00; i += tptg) {
        lmax = MAX(lmax, logits_row[i]);
    }
    float max_val = simd_max(lmax);
    if (tptg > N_SIMDWIDTH) {
        if (sgitg == 0) buf[tiisg] = -INFINITY;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) buf[sgitg] = max_val;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // sum of exponentials; stash the exp values in dst_row to avoid recomputing them
    float lsum = 0.0f;
    for (int i = tpitg; i < args.ne00; i += tptg) {
        const float exp_val = exp(logits_row[i] - max_val);
        lsum += exp_val;
        dst_row[i] = exp_val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float sum = simd_sum(lsum);
    if (tptg > N_SIMDWIDTH) {
        if (sgitg == 0) buf[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) buf[sgitg] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // d(loss)/d(logits_i) = (softmax_i - labels_i) * grad / nrows
    const float inv_sum    = 1.0f / sum;
    const float d_by_nrows = grad[0] / args.nrows;

    for (int i = tpitg; i < args.ne00; i += tptg) {
        const float softmax_i = dst_row[i] * inv_sum; // exp(logits_i - max) / sum_j exp(logits_j - max)
        dst_row[i] = (softmax_i - labels_row[i]) * d_by_nrows;
    }
}

typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;

template [[host_name("kernel_cross_entropy_loss_f32")]]
kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;

template [[host_name("kernel_cross_entropy_loss_back_f32")]]
kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
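For comparison, a scalar C++ sketch of the gradient that kernel_cross_entropy_loss_back applies, reusing the output buffer as scratch the same way the kernel does; the names are illustrative, not part of this change:

#include <algorithm>
#include <cmath>

// d(loss)/d(logits[i]) = (softmax(logits)[i] - labels[i]) * grad / nrows
static void cross_entropy_back_row(const float * logits, const float * labels, float grad,
                                   int n_classes, int nrows, float * dlogits) {
    float max_val = -INFINITY;
    for (int i = 0; i < n_classes; ++i) {
        max_val = std::max(max_val, logits[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n_classes; ++i) {
        dlogits[i] = std::exp(logits[i] - max_val); // stash exp values in the output
        sum += dlogits[i];
    }
    const float d_by_nrows = grad / nrows;
    for (int i = 0; i < n_classes; ++i) {
        dlogits[i] = (dlogits[i]/sum - labels[i]) * d_by_nrows;
    }
}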

// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
        constant ggml_metal_kargs_ssm_conv & args,