@@ -970,6 +970,7 @@ class GemmBatchFunctorThreadNM_vecm
970970 size_t i = block_i * wg_delta_n * wi_delta_n;
971971 size_t j = block_j * wg_delta_m * wi_total_delta_m;
972972
973+ using slmA_t = typename LocAccT1::value_type;
973974 using slmB_t = typename LocAccT2::value_type;
974975
975976 const size_t a_st0 = k;
@@ -1057,16 +1058,29 @@ class GemmBatchFunctorThreadNM_vecm
10571058 const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n);
10581059 const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs);
10591060 for (std::uint32_t pr_k = 0 ; pr_k < wi_delta_k; ++pr_k) {
1061+ std::array<slmA_t, wi_delta_n> pr_lhs{};
1062+ #pragma unroll
1063+ for (std::uint32_t pr_i = 0 ; pr_i < wi_delta_n; ++pr_i) {
1064+ pr_lhs[pr_i] =
1065+ local_lhs_block[pr_k * lo_lhs_st_k +
1066+ (local_i + pr_i * wg_delta_n)];
1067+ }
1068+
1069+ std::array<slmB_t, wi_delta_m_vecs> pr_rhs{};
1070+ #pragma unroll
1071+ for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j) {
1072+ pr_rhs[pr_j] =
1073+ local_rhs_block[pr_k * lo_rhs_rk_k +
1074+ (local_j + pr_j * wg_delta_m)];
1075+ }
1076+
10601077#pragma unroll
10611078 for (std::uint32_t pr_i = 0 ; pr_i < wi_delta_n; ++pr_i) {
10621079#pragma unroll
10631080 for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j)
10641081 {
10651082 private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1066- local_lhs_block[pr_k * lo_lhs_st_k +
1067- (local_i + pr_i * wg_delta_n)] *
1068- local_rhs_block[pr_k * lo_rhs_rk_k +
1069- (local_j + pr_j * wg_delta_m)];
1083+ pr_lhs[pr_i] * pr_rhs[pr_j];
10701084 }
10711085 }
10721086 }
0 commit comments