diff --git a/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c b/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c index 6e7b06884d..f83abb2339 100644 --- a/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c +++ b/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c @@ -1,22 +1,90 @@ #include "common.h" #include +#define BF16_WIDEN_ONE // Convert pre-hand and do operations in FP32 +#define USE_BF16_CVT // Comment out for pre-RVA23 systems like BananaPi + +#ifdef BF16_WIDEN_ONE +#define FORCEINLINE inline __attribute__((always_inline)) +#define B_UNROLL 64 + +// Convert from BF16 to FP32 +static void FORCEINLINE B_CONV(__bf16 *BB, FLOAT *CONV, BLASLONG count) +{ + BLASLONG count2 = (count & (B_UNROLL - 1)); + count &= -B_UNROLL; + while (count) { + vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4(BB, B_UNROLL); +#ifdef USE_BF16_CVT + vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8(B00, B_UNROLL); +#else + vfloat32m8_t B0 = __riscv_vreinterpret_v_u32m8_f32m8(__riscv_vsll_vx_u32m8( + __riscv_vwcvtu_x_x_v_u32m8(__riscv_vreinterpret_v_bf16m4_u16m4(B00), B_UNROLL), 16, B_UNROLL)); +#endif + __riscv_vse32_v_f32m8(CONV, B0, B_UNROLL); + BB += B_UNROLL; + CONV += B_UNROLL; + count -= B_UNROLL; + } + if (count2) { + BLASLONG gvl2 = __riscv_vsetvl_e16m4(count2); + vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4(BB, gvl2); +#ifdef USE_BF16_CVT + vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8(B00, gvl2); +#else + vfloat32m8_t B0 = __riscv_vreinterpret_v_u32m8_f32m8(__riscv_vsll_vx_u32m8( + __riscv_vwcvtu_x_x_v_u32m8(__riscv_vreinterpret_v_bf16m4_u16m4(B00), gvl2), 16, gvl2)); +#endif + __riscv_vse32_v_f32m8(CONV, B0, gvl2); + } +} +#endif + +#ifndef VECTORIZE_MEMSET +#define memset_zero(ptr, size, dir) memset(ptr, 0, size) +#else +void memset_zero(void *input, BLASLONG size, bool dir); +#endif + int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { BLASLONG gvl = 0; BLASLONG m_top = 0; BLASLONG n_top = 0; +#if !defined(BF16_WIDEN_ONE) || !defined(BF16_DONT_CONV) __bf16 *BB = (__bf16 *)(B); __bf16 *AA = (__bf16 *)(A); +#endif + +#ifdef BF16_WIDEN_ONE + FLOAT *CONV = (FLOAT *)(malloc((K * (8 + M)) * sizeof(FLOAT))); + if (!CONV) return 1; +#ifndef BF16_DONT_CONV + B_CONV(AA, CONV + (K * 8), K * M); +#else + memset_zero(CONV, (K * (8 + M)) * sizeof(FLOAT), false); +#endif +#endif // -- MAIN PASS for (BLASLONG j=0; j +#define BF16_WIDEN_ONE // Convert pre-hand and do operations in FP32 +#define USE_BF16_CVT // Comment out for pre-RVA23 systems + +#ifdef BF16_WIDEN_ONE +#define FORCEINLINE inline __attribute__((always_inline)) +#define B_UNROLL 32 + +// Convert from BF16 to FP32 +static void FORCEINLINE B_CONV(__bf16 *BB, FLOAT *CONV, BLASLONG count) +{ + BLASLONG count2 = (count & (B_UNROLL - 1)); + count &= -B_UNROLL; + while (count) { + vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4(BB, B_UNROLL); +#ifdef USE_BF16_CVT + vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8(B00, B_UNROLL); +#else + vfloat32m8_t B0 = __riscv_vreinterpret_v_u32m8_f32m8(__riscv_vsll_vx_u32m8( + __riscv_vwcvtu_x_x_v_u32m8(__riscv_vreinterpret_v_bf16m4_u16m4(B00), B_UNROLL), 16, B_UNROLL)); +#endif + __riscv_vse32_v_f32m8(CONV, B0, B_UNROLL); + BB += B_UNROLL; + CONV += B_UNROLL; + count -= B_UNROLL; + } + if (count2) { + BLASLONG gvl2 = __riscv_vsetvl_e16m4(count2); + vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4(BB, gvl2); +#ifdef USE_BF16_CVT + vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8(B00, gvl2); +#else + vfloat32m8_t B0 = __riscv_vreinterpret_v_u32m8_f32m8(__riscv_vsll_vx_u32m8( + __riscv_vwcvtu_x_x_v_u32m8(__riscv_vreinterpret_v_bf16m4_u16m4(B00), gvl2), 16, gvl2)); +#endif + __riscv_vse32_v_f32m8(CONV, B0, gvl2); + } +} +#endif + +#ifndef VECTORIZE_MEMSET +#define memset_zero(ptr, size, dir) memset(ptr, 0, size) +#else +void memset_zero(void *input, BLASLONG size, bool dir); +#endif + int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { BLASLONG gvl = 0; BLASLONG m_top = 0; BLASLONG n_top = 0; +#if !defined(BF16_WIDEN_ONE) || !defined(BF16_DONT_CONV) __bf16 *BB = (__bf16 *)(B); __bf16 *AA = (__bf16 *)(A); +#endif + +#ifdef BF16_WIDEN_ONE + FLOAT *CONV = (FLOAT *)(malloc((K * (8 + M)) * sizeof(FLOAT))); + if (!CONV) return 1; +#ifndef BF16_DONT_CONV + B_CONV(AA, CONV + (K * 8), K * M); +#else + memset_zero(CONV, (K * (8 + M)) * sizeof(FLOAT), false); +#endif +#endif // -- MAIN PASS for (BLASLONG j=0; j + +#define FP16_NARROW // Accumulate in FP16 + int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { BLASLONG gvl = 0; BLASLONG m_top = 0; BLASLONG n_top = 0; +#ifdef FP16_NARROW + IFLOAT alpha16 = (IFLOAT)(alpha); +#endif // -- MAIN PASS for (BLASLONG j=0; j +#define FP16_NARROW // Accumulate in FP16 + int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { BLASLONG gvl = 0; BLASLONG m_top = 0; BLASLONG n_top = 0; +#ifdef FP16_NARROW + IFLOAT alpha16 = (IFLOAT)(alpha); +#endif // -- MAIN PASS for (BLASLONG j=0; j