ggml/src/ggml-cpu/vec.cpp (182 additions, 18 deletions)
@@ -414,27 +414,61 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
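// writes the mean-centered values x[i] - mean to y and sums their squares
// (used to compute the variance)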
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
int i = 0;
ggml_float sum = 0;
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
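//      (illustration of the idea: after the 16-wide remainder loop below, an
//       8-wide AVX pass and then a 4-wide SSE pass would leave at most 3
//       elements for the scalar tail)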
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
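// main loop: 64 floats per iteration; the centered values are stored to y and
// squared into the accumulator with FMA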
for (; i + 63 < n; i += 64) {
__m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(mean));
__m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(mean));
__m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(mean));
__m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(mean));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_fmadd_ps(val1, val1, sum_v);
sum_v = _mm512_fmadd_ps(val2, val2, sum_v);
sum_v = _mm512_fmadd_ps(val3, val3, sum_v);
sum_v = _mm512_fmadd_ps(val4, val4, sum_v);
}
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(mean));
_mm512_storeu_ps(y + i, val);
sum_v = _mm512_fmadd_ps(val, val, sum_v);
}
sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
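// four independent accumulators keep the FMAs of consecutive iterations from
// serializing on a single register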
for (; i + 31 < n; i += 32) {
__m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(mean));
__m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(mean));
__m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(mean));
__m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(mean));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_fmadd_ps(val1, val1, sum_v1);
sum_v2 = _mm256_fmadd_ps(val2, val2, sum_v2);
sum_v3 = _mm256_fmadd_ps(val3, val3, sum_v3);
sum_v4 = _mm256_fmadd_ps(val4, val4, sum_v4);
}
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
_mm256_set1_ps(mean));
_mm256_storeu_ps(y + i, val);
sum_v1 = _mm256_fmadd_ps(val, val, sum_v1);
}
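// combine the four accumulators, fold 256 bits to 128, then reduce to a scalar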
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
@@ -491,23 +525,60 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
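// writes y[i] = expf(x[i] - max) and accumulates the sum of the exponentials
// (the softmax normalizer)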
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max)));
__m512 val2 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max)));
__m512 val3 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max)));
__m512 val4 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max)));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_add_ps(sum_v, val1);
sum_v = _mm512_add_ps(sum_v, val2);
sum_v = _mm512_add_ps(sum_v, val3);
sum_v = _mm512_add_ps(sum_v, val4);
}
for (; i + 15 < n; i += 16) {
__m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(max)));
_mm512_storeu_ps(y + i, val);
sum_v = _mm512_add_ps(sum_v, val);
}
sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max)));
__m256 val2 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max)));
__m256 val3 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max)));
__m256 val4 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max)));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_add_ps(sum_v1, val1);
sum_v2 = _mm256_add_ps(sum_v2, val2);
sum_v3 = _mm256_add_ps(sum_v3, val3);
sum_v4 = _mm256_add_ps(sum_v4, val4);
}
for (; i + 7 < n; i += 8) {
__m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
_mm256_set1_ps(max)));
_mm256_storeu_ps(y + i, val);
sum_v1 = _mm256_add_ps(sum_v1, val);
}
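// same combine-and-fold reduction as in ggml_vec_cvar_f32 above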
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
@@ -563,10 +634,103 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
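// writes y[i] = x[i] - max and returns log(sum(exp(x[i] - max))), the log of
// the softmax normalizer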

int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max));
__m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max));
__m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max));
__m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val1));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val2));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val3));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val4));
}
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i, val);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val));
}
sum = (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max));
__m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max));
__m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max));
__m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val1));
sum_v2 = _mm256_add_ps(sum_v2, ggml_v_expf(val2));
sum_v3 = _mm256_add_ps(sum_v3, ggml_v_expf(val3));
sum_v4 = _mm256_add_ps(sum_v4, ggml_v_expf(val4));
}
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i, val);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val));
}
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), _mm_set1_ps(max));
_mm_storeu_ps(y + i, val);
val = ggml_v_expf(val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
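// _mm_movehdup_ps needs SSE3, which any AVX-capable target provides; the
// #else branch below sticks to plain SSE shuffles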
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
val = _mm_add_ps(val, tmp);
tmp = _mm_movehl_ps(tmp, val);
val = _mm_add_ss(val, tmp);
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
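// fully predicated loop: svwhilelt masks off the tail lanes, so the scalar
// tail loop below never runs on this path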
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svfloat32_t val = svsub_f32_x(pg, svld1_f32(pg, x + i), svdup_n_f32_x(pg, max));
svst1_f32(pg, y + i, val);
sum += (ggml_float)svaddv_f32(pg, ggml_v_expf(pg, val));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = vsubq_f32(vld1q_f32(x + i), vdupq_n_f32(max));
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(ggml_v_expf(val));
}
#elif defined(__riscv_v_intrinsic)
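// strip-mined loop: vsetvl returns the number of lanes handled per iteration,
// and the widening reduction accumulates the f32 partial sums into one f64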
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(ggml_v_expf_m2(val, avl), vsum, avl);
}
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = x[i] - max;
y[i] = val;
sum += (ggml_float)expf(val);
}
return (ggml_float)logf(sum);
}
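For readers checking the vectorized paths against a reference, below is a minimal scalar sketch of what the three kernels compute. It mirrors the scalar tail loops of the patch; the ref_ names are illustrative and not part of the PR, ggml_float is assumed to be a typedef for double as in ggml's headers, and since the returns of the first two functions fall outside this excerpt, the sketches return the raw accumulated sums (the real kernels may normalize before returning).

#include <math.h>

typedef double ggml_float; // assumption: matches the typedef in ggml's headers

// y[i] = x[i] - mean; returns the sum of squared deviations
static ggml_float ref_cvar_f32(const int n, float * y, const float * x, const float mean) {
    ggml_float sum = 0;
    for (int i = 0; i < n; ++i) {
        const float v = x[i] - mean;
        y[i] = v;
        sum += (ggml_float)(v * v);
    }
    return sum;
}

// y[i] = expf(x[i] - max); returns the sum of the exponentials
static ggml_float ref_soft_max_f32(const int n, float * y, const float * x, const float max) {
    ggml_float sum = 0;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(x[i] - max);
        sum += (ggml_float)y[i];
    }
    return sum;
}

// y[i] = x[i] - max; returns the log of the softmax normalizer
static ggml_float ref_log_soft_max_f32(const int n, float * y, const float * x, const float max) {
    ggml_float sum = 0;
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] - max;
        sum += (ggml_float)expf(y[i]);
    }
    return (ggml_float)logf(sum);
}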