Skip to content

Commit 0a0bba0

Browse files
authored
ggml-hexagon: swiglu_oai operation (#18114)
* snapshot: debug ggml-hexagon swiglu-oai * fix: fix hvx_min_scalar_f32 * feat: working swiglu-oai * chore: fix formatting issue
1 parent 5166aaf commit 0a0bba0

File tree

4 files changed

+72
-37
lines changed

4 files changed

+72
-37
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3312,7 +3312,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
33123312
break;
33133313

33143314
case GGML_OP_GLU:
3315-
if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) {
3315+
if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) ) {
33163316
supp = ggml_hexagon_supported_activations(sess, op);
33173317
}
33183318
break;

ggml/src/ggml-hexagon/htp/act-ops.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
231231
// x (src0_spad_data) = std::min(src0_p[k], limit);
232232
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
233233
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
234-
hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
234+
hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
235235
// y (src1_spad_data) = y1 + 1.f
236236
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
237237
// x1 (dst_spad_data) = alpha * (x)

ggml/src/ggml-hexagon/htp/hvx-utils.c

Lines changed: 69 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -948,35 +948,45 @@ float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
948948
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
949949
size_t left_over = num_elems & (VLEN_FP32 - 1);
950950
size_t num_elems_whole = num_elems - left_over;
951-
951+
int unalign_address = 0;
952952
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
953953
FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
954+
unalign_address = 1;
954955
}
955956

956-
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
957-
958957
const float * src_f = (const float *) src;
959958

960-
HVX_Vector vec_min = Q6_V_vsplat_R(val);
959+
HVX_Vector vec_min = hvx_vec_splat_fp32(val);
961960

962-
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
963-
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
961+
if(unalign_address == 0){
962+
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
963+
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
964964

965-
#pragma unroll(4)
966-
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
967-
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
968-
*vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
965+
#pragma unroll(4)
966+
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
967+
HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
968+
*vec_out++ = (min_clamp);
969+
}
970+
}else{
971+
HVX_UVector * restrict vec_in = (HVX_Vector *) src;
972+
HVX_UVector * restrict vec_out = (HVX_Vector *) dst;
973+
974+
#pragma unroll(4)
975+
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
976+
HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
977+
*vec_out++ = (min_clamp);
978+
}
969979
}
970980

971-
if (left_over > 0) {
981+
if (left_over > 0 ) {
972982
const float * srcf = (const float *) src + num_elems_whole;
973983
float * dstf = (float *) dst + num_elems_whole;
974984

975-
HVX_Vector in = *(HVX_UVector *) srcf;
985+
HVX_UVector in = *(HVX_UVector *) srcf;
976986

977-
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
987+
HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
978988

979-
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
989+
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
980990
}
981991
}
982992

@@ -988,46 +998,70 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
988998
size_t left_over = num_elems & (VLEN_FP32 - 1);
989999
size_t num_elems_whole = num_elems - left_over;
9901000

1001+
int unalign_address = 0;
9911002
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
9921003
FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
1004+
unalign_address = 1;
9931005
}
9941006

995-
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
996-
997-
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
998-
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
999-
10001007
HVX_Vector range_left = hvx_vec_splat_fp32(limit_left);
10011008
HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
10021009

1003-
#pragma unroll(4)
1004-
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
1005-
HVX_Vector in_vec = *vec_in++;
1006-
HVX_Vector temp_v = in_vec;
1010+
if(unalign_address == 0){
1011+
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
1012+
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
10071013

1008-
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
1009-
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
10101014

1011-
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
1012-
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
10131015

1014-
*vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
1016+
#pragma unroll(4)
1017+
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
1018+
HVX_Vector in_vec = *vec_in++;
1019+
HVX_Vector temp_v = in_vec;
1020+
1021+
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
1022+
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
1023+
1024+
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
1025+
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
1026+
1027+
*vec_out++ = in_vec;
1028+
}
1029+
1030+
}else{
1031+
1032+
HVX_UVector * restrict vec_in = (HVX_UVector *) src;
1033+
HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
1034+
1035+
#pragma unroll(4)
1036+
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
1037+
HVX_Vector in_vec = *vec_in++;
1038+
HVX_Vector temp_v = in_vec;
1039+
1040+
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
1041+
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
1042+
1043+
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
1044+
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
1045+
1046+
*vec_out++ = in_vec;
1047+
}
1048+
10151049
}
10161050

10171051
if (left_over > 0) {
10181052
const float * srcf = (const float *) src + num_elems_whole;
10191053
float * dstf = (float *) dst + num_elems_whole;
10201054

1021-
HVX_Vector in = *(HVX_UVector *) srcf;
1055+
HVX_Vector in_vec = *(HVX_UVector *) srcf;
10221056

1023-
HVX_Vector temp_v = in;
1057+
HVX_Vector temp_v = in_vec;
10241058

1025-
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
1026-
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
1059+
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
1060+
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
10271061

1028-
in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
1029-
in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
1062+
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
1063+
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
10301064

1031-
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
1065+
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
10321066
}
10331067
}

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -807,6 +807,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
807807
break;
808808

809809
case HTP_OP_GLU_SWIGLU:
810+
case HTP_OP_GLU_SWIGLU_OAI:
810811
case HTP_OP_SOFTMAX:
811812
if ((n_bufs != 2) && (n_bufs != 3)) {
812813
FARF(ERROR, "Bad act-req buffer list");

0 commit comments

Comments
 (0)