Skip to content

Commit 44a73f4

Browse files
Math: Optimise 16-bit matrix multiplication functions.
Improve mat_multiply and mat_multiply_elementwise for 16-bit signed integers by refactoring operations and simplifying the handling of Q0 data. Both functions now use incremented pointers in loop expressions for better legibility and potential compiler optimisation. Changes: - Improve x pointer increments in mat_multiply. - Integrate the Q0 conditional logic into the main flow. - Clean up the mat_multiply_elementwise loop structure. - Ensure precision with appropriate fractional bit shifts. These modifications aim to improve accuracy and efficiency in fixed-point arithmetic operations. Signed-off-by: Shriram Shastry <malladi.sastry@intel.com>
1 parent 07b762e commit 44a73f4

File tree

1 file changed

+29
-55
lines changed

1 file changed

+29
-55
lines changed

src/math/matrix.c

Lines changed: 29 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,55 +3,36 @@
33
// Copyright(c) 2022 Intel Corporation. All rights reserved.
44
//
55
// Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6+
// Shriram Shastry <malladi.sastry@linux.intel.com>
67

78
#include <sof/math/matrix.h>
89
#include <errno.h>
910
#include <stdint.h>
11+
#include <stdlib.h> /* for malloc and free */
12+
#include <string.h>
13+
#include <stdbool.h>
1014

1115
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
1216
{
1317
int64_t s;
14-
int16_t *x;
15-
int16_t *y;
16-
int16_t *z = c->data;
18+
int16_t *x, *y, *z = c->data;
1719
int i, j, k;
1820
int y_inc = b->columns;
19-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
21+
const int shift = a->fractions + b->fractions - c->fractions - 1;
2022

2123
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
2224
return -EINVAL;
2325

24-
/* If all data is Q0 */
25-
if (shift_minus_one == -1) {
26-
for (i = 0; i < a->rows; i++) {
27-
for (j = 0; j < b->columns; j++) {
28-
s = 0;
29-
x = a->data + a->columns * i;
30-
y = b->data + j;
31-
for (k = 0; k < b->rows; k++) {
32-
s += (int32_t)(*x) * (*y);
33-
x++;
34-
y += y_inc;
35-
}
36-
*z = (int16_t)s; /* For Q16.0 */
37-
z++;
38-
}
39-
}
40-
41-
return 0;
42-
}
43-
4426
for (i = 0; i < a->rows; i++) {
4527
for (j = 0; j < b->columns; j++) {
46-
s = 0;
47-
x = a->data + a->columns * i;
48-
y = b->data + j;
28+
s = 0; x = a->data + a->columns * i; y = b->data + j;
4929
for (k = 0; k < b->rows; k++) {
50-
s += (int32_t)(*x) * (*y);
51-
x++;
52-
y += y_inc;
30+
s += (int32_t)(*x++) * (*y); y += y_inc;
5331
}
54-
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
32+
if (shift == -1)
33+
*z = (int16_t)s;
34+
else
35+
*z = (int16_t)(((s >> shift) + 1) >> 1);
5536
z++;
5637
}
5738
}
@@ -60,36 +41,29 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_
6041

6142
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
6243
struct mat_matrix_16b *c)
63-
{ int64_t p;
44+
{
45+
if (a->columns != b->columns || a->rows != b->rows)
46+
return -EINVAL;
47+
48+
const int total_elements = a->rows * a->columns;
49+
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
50+
6451
int16_t *x = a->data;
6552
int16_t *y = b->data;
6653
int16_t *z = c->data;
67-
int i;
68-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
69-
70-
if (a->columns != b->columns || b->columns != c->columns ||
71-
a->rows != b->rows || b->rows != c->rows) {
72-
return -EINVAL;
73-
}
54+
int64_t p;
7455

75-
/* If all data is Q0 */
7656
if (shift_minus_one == -1) {
77-
for (i = 0; i < a->rows * a->columns; i++) {
78-
*z = *x * *y;
79-
x++;
80-
y++;
81-
z++;
82-
}
57+
// If all data is Q0
58+
for (int i = 0; i < total_elements; i++)
59+
z[i] = x[i] * y[i];
8360

84-
return 0;
85-
}
86-
87-
for (i = 0; i < a->rows * a->columns; i++) {
88-
p = (int32_t)(*x) * *y;
89-
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90-
x++;
91-
y++;
92-
z++;
61+
} else {
62+
// General case
63+
for (int i = 0; i < total_elements; i++) {
64+
p = (int32_t)x[i] * y[i];
65+
z[i] = (int16_t)(((p >> shift_minus_one) + 1) >> 1); // Shift to Qx.y
66+
}
9367
}
9468

9569
return 0;

0 commit comments

Comments
 (0)