Skip to content

Commit b3fa07c

Browse files
Math: Optimise 16-bit matrix multiplication functions.
Improve mat_multiply and mat_multiply_elementwise for 16-bit signed integers by refactoring operations and simplifying handling of Q0 data. Both functions now use incremented pointers in loop expressions for better legibility and potential compiler optimisation. Changes: - Improved x pointer increments in mat_multiply. - Integrating Q0 conditional logic into the main flow. - Clean up the mat_multiply_elementwise loop structure. - Ensure precision with appropriate fractional bit shifts. These modifications aim to improve accuracy and efficiency in fixed-point arithmetic operations. Signed-off-by: Shriram Shastry <malladi.sastry@intel.com>
1 parent 07b762e commit b3fa07c

File tree

1 file changed

+26
-54
lines changed

1 file changed

+26
-54
lines changed

src/math/matrix.c

Lines changed: 26 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// Copyright(c) 2022 Intel Corporation. All rights reserved.
44
//
55
// Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6+
// Shriram Shastry <malladi.sastry@linux.intel.com>
67

78
#include <sof/math/matrix.h>
89
#include <errno.h>
@@ -11,47 +12,24 @@
1112
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
1213
{
1314
int64_t s;
14-
int16_t *x;
15-
int16_t *y;
16-
int16_t *z = c->data;
15+
int16_t *x, *y, *z = c->data;
1716
int i, j, k;
1817
int y_inc = b->columns;
19-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
18+
const int shift = a->fractions + b->fractions - c->fractions - 1;
2019

2120
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
2221
return -EINVAL;
2322

24-
/* If all data is Q0 */
25-
if (shift_minus_one == -1) {
26-
for (i = 0; i < a->rows; i++) {
27-
for (j = 0; j < b->columns; j++) {
28-
s = 0;
29-
x = a->data + a->columns * i;
30-
y = b->data + j;
31-
for (k = 0; k < b->rows; k++) {
32-
s += (int32_t)(*x) * (*y);
33-
x++;
34-
y += y_inc;
35-
}
36-
*z = (int16_t)s; /* For Q16.0 */
37-
z++;
38-
}
39-
}
40-
41-
return 0;
42-
}
43-
4423
for (i = 0; i < a->rows; i++) {
4524
for (j = 0; j < b->columns; j++) {
46-
s = 0;
47-
x = a->data + a->columns * i;
48-
y = b->data + j;
25+
s = 0; x = a->data + a->columns * i; y = b->data + j;
4926
for (k = 0; k < b->rows; k++) {
50-
s += (int32_t)(*x) * (*y);
51-
x++;
52-
y += y_inc;
27+
s += (int32_t)(*x++) * (*y); y += y_inc;
5328
}
54-
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
29+
if (shift == -1)
30+
*z = (int16_t)s;
31+
else
32+
*z = (int16_t)(((s >> shift) + 1) >> 1);
5533
z++;
5634
}
5735
}
@@ -60,36 +38,30 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_
6038

6139
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
6240
struct mat_matrix_16b *c)
63-
{ int64_t p;
41+
{
42+
if (a->columns != b->columns || a->rows != b->rows)
43+
return -EINVAL;
44+
45+
const int total_elements = a->rows * a->columns;
46+
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
47+
6448
int16_t *x = a->data;
6549
int16_t *y = b->data;
6650
int16_t *z = c->data;
51+
int64_t p;
6752
int i;
68-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
69-
70-
if (a->columns != b->columns || b->columns != c->columns ||
71-
a->rows != b->rows || b->rows != c->rows) {
72-
return -EINVAL;
73-
}
7453

75-
/* If all data is Q0 */
7654
if (shift_minus_one == -1) {
77-
for (i = 0; i < a->rows * a->columns; i++) {
78-
*z = *x * *y;
79-
x++;
80-
y++;
81-
z++;
82-
}
55+
// If all data is Q0
56+
for (i = 0; i < total_elements; i++)
57+
z[i] = x[i] * y[i];
8358

84-
return 0;
85-
}
86-
87-
for (i = 0; i < a->rows * a->columns; i++) {
88-
p = (int32_t)(*x) * *y;
89-
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90-
x++;
91-
y++;
92-
z++;
59+
} else {
60+
// General case
61+
for (i = 0; i < total_elements; i++) {
62+
p = (int32_t)x[i] * y[i];
63+
z[i] = (int16_t)(((p >> shift_minus_one) + 1) >> 1); // Shift to Qx.y
64+
}
9365
}
9466

9567
return 0;

0 commit comments

Comments
 (0)