|
3 | 3 | // Copyright(c) 2022 Intel Corporation. All rights reserved. |
4 | 4 | // |
5 | 5 | // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com> |
| 6 | +// Shriram Shastry <malladi.sastry@linux.intel.com> |
6 | 7 |
|
7 | 8 | #include <sof/math/matrix.h> |
8 | 9 | #include <errno.h> |
9 | 10 | #include <stdint.h> |
10 | 11 |
|
11 | 12 | int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) |
12 | 13 | { |
13 | | - int64_t s; |
14 | | - int16_t *x; |
15 | | - int16_t *y; |
16 | | - int16_t *z = c->data; |
17 | | - int i, j, k; |
18 | | - int y_inc = b->columns; |
19 | | - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; |
| 14 | + /* check for NULL pointers */ |
| 15 | + if (!a || !b || !c) |
| 16 | + return -EINVAL; |
20 | 17 |
|
| 18 | + /* check for dimensions compatibility */ |
21 | 19 | if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) |
22 | 20 | return -EINVAL; |
23 | 21 |
|
24 | | - /* If all data is Q0 */ |
25 | | - if (shift_minus_one == -1) { |
26 | | - for (i = 0; i < a->rows; i++) { |
27 | | - for (j = 0; j < b->columns; j++) { |
28 | | - s = 0; |
29 | | - x = a->data + a->columns * i; |
30 | | - y = b->data + j; |
31 | | - for (k = 0; k < b->rows; k++) { |
32 | | - s += (int32_t)(*x) * (*y); |
33 | | - x++; |
34 | | - y += y_inc; |
35 | | - } |
36 | | - *z = (int16_t)s; /* For Q16.0 */ |
37 | | - z++; |
38 | | - } |
39 | | - } |
| 22 | + int64_t acc; |
| 23 | + int16_t *x, *y, *z = c->data; |
| 24 | + int i, j, k; |
| 25 | + /* Increment for pointer y to jump to the next row in matrix B */ |
| 26 | + int y_inc = b->columns; |
| 27 | + const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; |
| 28 | + /* Check for integer overflows during the shift operation. */ |
| 29 | + if (shift_minus_one < -1 || shift_minus_one > 62) |
| 30 | + return -ERANGE; |
40 | 31 |
|
41 | | - return 0; |
42 | | - } |
| 32 | + const int shift = shift_minus_one + 1; |
| 33 | + |
| 34 | + /* Calculate offset for rounding. The offset is only needed when shift is |
| 35 | + * non-negative (> 0). Since shift = shift_minus_one + 1, the check for non-negative |
| 36 | + * value is (shift_minus_one >= -1) |
| 37 | + */ |
| 38 | + const int64_t offset = (shift_minus_one >= -1) ? (1LL << shift_minus_one) : 0; |
43 | 39 |
|
44 | 40 | for (i = 0; i < a->rows; i++) { |
45 | 41 | for (j = 0; j < b->columns; j++) { |
46 | | - s = 0; |
| 42 | + acc = 0; |
| 43 | + /* Position pointer x at the beginning of the ith row in matrix A */ |
47 | 44 | x = a->data + a->columns * i; |
| 45 | + /* Position pointer y at the beginning of the jth column in matrix B */ |
48 | 46 | y = b->data + j; |
49 | 47 | for (k = 0; k < b->rows; k++) { |
50 | | - s += (int32_t)(*x) * (*y); |
51 | | - x++; |
| 48 | + /* Multiply elements and add to sum */ |
| 49 | + acc += (int64_t)(*x++) * (*y); |
| 50 | + /* Move pointer y to the next element in the column */ |
52 | 51 | y += y_inc; |
53 | 52 | } |
54 | | - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ |
55 | | - z++; |
| 53 | + /* If shift == 0, then shift_minus_one == -1, which means the data is Q16.0 |
| 54 | + * Otherwise, add the offset before the shift to round up if necessary |
| 55 | + */ |
| 56 | + *z++ = (shift == 0) ? (int16_t)acc : (int16_t)((acc + offset) >> shift); |
56 | 57 | } |
57 | 58 | } |
| 59 | + |
58 | 60 | return 0; |
59 | 61 | } |
60 | 62 |
|
61 | 63 | int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, |
62 | 64 | struct mat_matrix_16b *c) |
63 | | -{ int64_t p; |
| 65 | +{ |
| 66 | + /* check for NULL pointers */ |
| 67 | + if (!a || !b || !c) |
| 68 | + return -EINVAL; |
| 69 | + /* Validate dimensions match for elementwise multiplication */ |
| 70 | + if (a->columns != b->columns || a->rows != b->rows) |
| 71 | + return -EINVAL; |
| 72 | + |
| 73 | + const int total_elements = a->rows * a->columns; |
| 74 | + /* Calculate shift for result Q format */ |
| 75 | + const int shift = a->fractions + b->fractions - c->fractions; |
| 76 | + /* Check for integer overflows during the shift operation. */ |
| 77 | + if (shift < -1 || shift > 62) |
| 78 | + return -ERANGE; |
| 79 | + /* Offset for rounding */ |
| 80 | + const int64_t offset = (shift >= 0) ? (1LL << (shift - 1)) : 0; |
| 81 | + |
64 | 82 | int16_t *x = a->data; |
65 | 83 | int16_t *y = b->data; |
66 | 84 | int16_t *z = c->data; |
| 85 | + int64_t acc; |
67 | 86 | int i; |
68 | | - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; |
69 | 87 |
|
70 | | - if (a->columns != b->columns || b->columns != c->columns || |
71 | | - a->rows != b->rows || b->rows != c->rows) { |
72 | | - return -EINVAL; |
73 | | - } |
74 | | - |
75 | | - /* If all data is Q0 */ |
76 | | - if (shift_minus_one == -1) { |
77 | | - for (i = 0; i < a->rows * a->columns; i++) { |
78 | | - *z = *x * *y; |
79 | | - x++; |
80 | | - y++; |
81 | | - z++; |
| 88 | + /* When no shifting is required (shift -1 indicated), simply multiply */ |
| 89 | + if (shift == -1) { |
| 90 | + for (i = 0; i < total_elements; i++) |
| 91 | + z[i] = x[i] * y[i]; /* Direct elementwise multiplication */ |
| 92 | + } else { |
| 93 | + /* General case with shifting and offset for rounding */ |
| 94 | + for (i = 0; i < total_elements; i++) { |
| 95 | + acc = (int64_t)x[i] * (int64_t)y[i]; /* Cast to int64_t to avoid overflow */ |
| 96 | + z[i] = (int16_t)((acc + offset) >> shift); /* Apply shift and offset */ |
82 | 97 | } |
83 | | - |
84 | | - return 0; |
85 | | - } |
86 | | - |
87 | | - for (i = 0; i < a->rows * a->columns; i++) { |
88 | | - p = (int32_t)(*x) * *y; |
89 | | - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ |
90 | | - x++; |
91 | | - y++; |
92 | | - z++; |
93 | 98 | } |
94 | 99 |
|
95 | 100 | return 0; |
|
0 commit comments