Skip to content

Commit 5493610

Browse files
Math: Optimise 16-bit matrix multiplication functions.
Improve mat_multiply and mat_multiply_elementwise for 16-bit signed integers by refactoring operations and simplifying handling of Q0 data. Both functions now use incremented pointers in loop expressions for better legibility and potential compiler optimisation. Changes: - Improved x pointer increments in mat_multiply. - Integrating Q0 conditional logic into the main flow. - Clean up the mat_multiply_elementwise loop structure. - Ensure precision with appropriate fractional bit shifts. These modifications aim to improve accuracy and efficiency in fixed-point arithmetic operations. Signed-off-by: Shriram Shastry <malladi.sastry@intel.com>
1 parent 07b762e commit 5493610

File tree

1 file changed

+59
-54
lines changed

1 file changed

+59
-54
lines changed

src/math/matrix.c

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,98 @@
33
// Copyright(c) 2022 Intel Corporation. All rights reserved.
44
//
55
// Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6+
// Shriram Shastry <malladi.sastry@linux.intel.com>
67

78
#include <sof/math/matrix.h>
89
#include <errno.h>
910
#include <stdint.h>
1011

1112
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
1213
{
13-
int64_t s;
14-
int16_t *x;
15-
int16_t *y;
16-
int16_t *z = c->data;
17-
int i, j, k;
18-
int y_inc = b->columns;
19-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
14+
/* check for NULL pointers */
15+
if (!a || !b || !c)
16+
return -EINVAL;
2017

18+
/* check for dimensions compatibility */
2119
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
2220
return -EINVAL;
2321

24-
/* If all data is Q0 */
25-
if (shift_minus_one == -1) {
26-
for (i = 0; i < a->rows; i++) {
27-
for (j = 0; j < b->columns; j++) {
28-
s = 0;
29-
x = a->data + a->columns * i;
30-
y = b->data + j;
31-
for (k = 0; k < b->rows; k++) {
32-
s += (int32_t)(*x) * (*y);
33-
x++;
34-
y += y_inc;
35-
}
36-
*z = (int16_t)s; /* For Q16.0 */
37-
z++;
38-
}
39-
}
22+
int64_t acc;
23+
int16_t *x, *y, *z = c->data;
24+
int i, j, k;
25+
/* Increment for pointer y to jump to the next row in matrix B */
26+
int y_inc = b->columns;
27+
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
28+
/* Check for integer overflows during the shift operation. */
29+
if (shift_minus_one < -1 || shift_minus_one > 62)
30+
return -ERANGE;
4031

41-
return 0;
42-
}
32+
const int shift = shift_minus_one + 1;
33+
34+
/* Calculate offset for rounding. The offset is only needed when shift is
35+
* non-negative (> 0). Since shift = shift_minus_one + 1, the check for non-negative
36+
* value is (shift_minus_one >= -1)
37+
*/
38+
const int64_t offset = (shift_minus_one >= -1) ? (1LL << shift_minus_one) : 0;
4339

4440
for (i = 0; i < a->rows; i++) {
4541
for (j = 0; j < b->columns; j++) {
46-
s = 0;
42+
acc = 0;
43+
/* Position pointer x at the beginning of the ith row in matrix A */
4744
x = a->data + a->columns * i;
45+
/* Position pointer y at the beginning of the jth column in matrix B */
4846
y = b->data + j;
4947
for (k = 0; k < b->rows; k++) {
50-
s += (int32_t)(*x) * (*y);
51-
x++;
48+
/* Multiply elements and add to sum */
49+
acc += (int64_t)(*x++) * (*y);
50+
/* Move pointer y to the next element in the column */
5251
y += y_inc;
5352
}
54-
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
55-
z++;
53+
/* If shift == 0, then shift_minus_one == -1, which means the data is Q16.0
54+
* Otherwise, add the offset before the shift to round up if necessary
55+
*/
56+
*z++ = (shift == 0) ? (int16_t)acc : (int16_t)((acc + offset) >> shift);
5657
}
5758
}
59+
5860
return 0;
5961
}
6062

6163
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
6264
struct mat_matrix_16b *c)
63-
{ int64_t p;
65+
{
66+
/* check for NULL pointers */
67+
if (!a || !b || !c)
68+
return -EINVAL;
69+
/* Validate dimensions match for elementwise multiplication */
70+
if (a->columns != b->columns || a->rows != b->rows)
71+
return -EINVAL;
72+
73+
const int total_elements = a->rows * a->columns;
74+
/* Calculate shift for result Q format */
75+
const int shift = a->fractions + b->fractions - c->fractions;
76+
/* Check for integer overflows during the shift operation. */
77+
if (shift < -1 || shift > 62)
78+
return -ERANGE;
79+
/* Offset for rounding */
80+
const int64_t offset = (shift >= 0) ? (1LL << (shift - 1)) : 0;
81+
6482
int16_t *x = a->data;
6583
int16_t *y = b->data;
6684
int16_t *z = c->data;
85+
int64_t acc;
6786
int i;
68-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
6987

70-
if (a->columns != b->columns || b->columns != c->columns ||
71-
a->rows != b->rows || b->rows != c->rows) {
72-
return -EINVAL;
73-
}
74-
75-
/* If all data is Q0 */
76-
if (shift_minus_one == -1) {
77-
for (i = 0; i < a->rows * a->columns; i++) {
78-
*z = *x * *y;
79-
x++;
80-
y++;
81-
z++;
88+
/* When no shifting is required (shift -1 indicated), simply multiply */
89+
if (shift == -1) {
90+
for (i = 0; i < total_elements; i++)
91+
z[i] = x[i] * y[i]; /* Direct elementwise multiplication */
92+
} else {
93+
/* General case with shifting and offset for rounding */
94+
for (i = 0; i < total_elements; i++) {
95+
acc = (int64_t)x[i] * (int64_t)y[i]; /* Cast to int64_t to avoid overflow */
96+
z[i] = (int16_t)((acc + offset) >> shift); /* Apply shift and offset */
8297
}
83-
84-
return 0;
85-
}
86-
87-
for (i = 0; i < a->rows * a->columns; i++) {
88-
p = (int32_t)(*x) * *y;
89-
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90-
x++;
91-
y++;
92-
z++;
9398
}
9499

95100
return 0;

0 commit comments

Comments
 (0)