|
1 | 1 | """ |
2 | | -Strassen's Matrix Multiplication Algorithm (Descriptive Version) |
3 | | ---------------------------------------------------------------- |
| 2 | +Strassen's Matrix Multiplication Algorithm |
| 3 | +------------------------------------------ |
4 | 4 | An optimized divide-and-conquer algorithm for matrix multiplication that |
5 | 5 | reduces the number of multiplications from 8 (in the naive approach) |
6 | 6 | to 7 per recursion step. |
|
# Type alias used throughout this module: a matrix is a list of rows,
# and each row is a list of integers.
Matrix = list[list[int]]
16 | 16 |
|
17 | 17 |
|
def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Add two matrices of the same shape element-wise.

    Rows and elements are paired positionally, so the inputs are expected
    to have identical dimensions. A new matrix is returned; neither input
    is modified.

    >>> add([[1,2],[3,4]], [[5,6],[7,8]])
    [[6, 8], [10, 12]]
    >>> add([], [])
    []
    """
    return [
        [x + y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(matrix_a, matrix_b)
    ]
24 | 27 |
|
25 | 28 |
|
def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Subtract matrix_b from matrix_a element-wise.

    Rows and elements are paired positionally, so the inputs are expected
    to have identical dimensions. A new matrix is returned; neither input
    is modified.

    >>> sub([[5,6],[7,8]], [[1,2],[3,4]])
    [[4, 4], [4, 4]]
    >>> sub([], [])
    []
    """
    return [
        [x - y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(matrix_a, matrix_b)
    ]
32 | 38 |
|
33 | 39 |
|
def naive_mul(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Multiply two square matrices using the naive O(n^3) method.

    Both inputs are expected to be n x n, where n = len(matrix_a).
    A new n x n matrix is returned; the inputs are not modified.

    >>> naive_mul([[1,2],[3,4]], [[5,6],[7,8]])
    [[19, 22], [43, 50]]
    >>> naive_mul([], [])
    []
    """
    n = len(matrix_a)
    result = [[0] * n for _ in range(n)]
    # i-k-j loop order: the innermost loop walks one row of B and one row of
    # the result, so every list is traversed sequentially.
    for row_a, row_result in zip(matrix_a, result):
        for a_ik, row_b in zip(row_a, matrix_b):
            for j in range(n):
                row_result[j] += a_ik * row_b[j]
    return result
46 | 58 |
|
47 | 59 |
|
def next_power_of_two(n: int) -> int:
    """
    Return the smallest power of two greater than or equal to n.

    Any n less than or equal to 1 (including zero and negative values)
    yields 1.

    >>> next_power_of_two(5)
    8
    >>> next_power_of_two(8)
    8
    >>> next_power_of_two(0)
    1
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed to represent n - 1,
    # so shifting 1 by that amount gives the first power of two >= n.
    return 1 << (n - 1).bit_length()
56 | 71 |
|
57 | 72 |
|
def pad_matrix(matrix: Matrix, size: int) -> Matrix:
    """
    Pad a matrix with zeros to a size x size square.

    The original entries stay in the top-left corner; zeros are appended
    to the right of every row and as extra rows at the bottom. A new
    matrix is returned and the input is not modified. An empty matrix is
    padded to an all-zero size x size matrix.

    Raises:
        ValueError: if size is smaller than the matrix's row or column count.

    >>> pad_matrix([[1,2],[3,4]], 4)
    [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
    >>> pad_matrix([], 2)
    [[0, 0], [0, 0]]
    """
    rows = len(matrix)
    cols = max((len(row) for row in matrix), default=0)
    if size < rows or size < cols:
        raise ValueError(f"Cannot pad a {rows}x{cols} matrix to size {size}")

    # Extend each existing row to the right, then append all-zero rows.
    padded = [row + [0] * (size - len(row)) for row in matrix]
    padded.extend([0] * size for _ in range(size - rows))
    return padded
68 | 87 |
|
69 | 88 |
|
def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix:
    """
    Remove padding from a matrix by keeping only its top-left block.

    Returns a new matrix made of the first `rows` rows of `matrix`, each
    truncated to its first `cols` entries. The input is not modified.

    >>> unpad_matrix([[1,2,0],[3,4,0],[0,0,0]], 2, 2)
    [[1, 2], [3, 4]]
    """
    return [row[:cols] for row in matrix[:rows]]
77 | 97 |
|
78 | 98 |
|
def split(matrix: Matrix) -> tuple[Matrix, Matrix, Matrix, Matrix]:
    """
    Split a square matrix into four quadrants.

    The quadrants are returned in the order
    (top-left, top-right, bottom-left, bottom-right). The matrix is cut at
    index len(matrix) // 2 in both directions, so an even size is expected
    for the quadrants to be equal. Each quadrant is a new list of new rows.

    >>> split([[1,2],[3,4]])
    ([[1]], [[2]], [[3]], [[4]])
    """
    mid = len(matrix) // 2
    upper, lower = matrix[:mid], matrix[mid:]
    top_left = [row[:mid] for row in upper]
    top_right = [row[mid:] for row in upper]
    bottom_left = [row[:mid] for row in lower]
    bottom_right = [row[mid:] for row in lower]
    return top_left, top_right, bottom_left, bottom_right
93 | 113 |
|
94 | 114 |
|
def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
    """
    Join four quadrants into a single matrix.

    c11 is the top-left quadrant, c12 the top-right, c21 the bottom-left
    and c22 the bottom-right. All four are expected to have the same
    dimensions. A new matrix is returned; the quadrants are not modified.

    >>> join([[1]], [[2]], [[3]], [[4]])
    [[1, 2], [3, 4]]
    """
    # Each output row is a left-quadrant row followed by the matching
    # right-quadrant row; top rows come first, bottom rows after.
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom
112 | 132 |
|
113 | 133 |
|
def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
    """
    Multiply two square matrices using Strassen's algorithm.

    Uses naive multiplication for matrices smaller than threshold.
    Inputs whose size is not a power of two are zero-padded up to the next
    power of two so they can be halved repeatedly; the padding is removed
    from the result before it is returned. Two empty matrices multiply to
    an empty matrix.

    Raises:
        ValueError: if the inputs are not square matrices of the same size.

    >>> strassen([[1,2],[3,4]], [[5,6],[7,8]])
    [[19, 22], [43, 50]]
    >>> strassen([], [])
    []
    """
    n_orig = len(matrix_a)
    # Validate with an explicit exception (asserts are stripped under -O) and
    # check every row, not just the first. This runs before anything indexes
    # into the matrices, so empty input is handled instead of crashing.
    if (
        len(matrix_b) != n_orig
        or any(len(row) != n_orig for row in matrix_a)
        or any(len(row) != n_orig for row in matrix_b)
    ):
        raise ValueError("Only square matrices supported")

    if n_orig == 0:
        return []

    # Pad to the next power of two so every recursion level splits evenly.
    if (m := next_power_of_two(n_orig)) != n_orig:
        a_pad = pad_matrix(matrix_a, m)
        b_pad = pad_matrix(matrix_b, m)
    else:
        a_pad, b_pad = matrix_a, matrix_b

    c_pad = _strassen_recursive(a_pad, b_pad, threshold)
    return unpad_matrix(c_pad, n_orig, n_orig)
145 | 158 |
|
def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
    """
    Recursive core of Strassen's algorithm for two n x n matrices.

    Falls back to naive multiplication when n is at most threshold, when
    n is less than 2, or when n is odd (an odd size cannot be split into
    four equal quadrants). Otherwise both matrices are split into
    quadrants, seven sub-products are computed recursively, and the
    result quadrants are assembled from them.
    """
    n = len(matrix_a)
    # max(threshold, 1) makes sizes 0 and 1 terminate even for a
    # non-positive threshold; the odd-size check prevents uneven splits.
    if n <= max(threshold, 1) or n % 2:
        return naive_mul(matrix_a, matrix_b)

    a11, a12, a21, a22 = split(matrix_a)
    b11, b12, b21, b22 = split(matrix_b)

    # The seven Strassen products (instead of the eight of the naive
    # block-wise method).
    m1 = _strassen_recursive(add(a11, a22), add(b11, b22), threshold)
    m2 = _strassen_recursive(add(a21, a22), b11, threshold)
    m3 = _strassen_recursive(a11, sub(b12, b22), threshold)
    m4 = _strassen_recursive(a22, sub(b21, b11), threshold)
    m5 = _strassen_recursive(add(a11, a12), b22, threshold)
    m6 = _strassen_recursive(sub(a21, a11), add(b11, b12), threshold)
    m7 = _strassen_recursive(sub(a12, a22), add(b21, b22), threshold)

    # Combine the products into the four quadrants of the result.
    c11 = add(sub(add(m1, m4), m5), m7)  # m1 + m4 - m5 + m7
    c12 = add(m3, m5)
    c21 = add(m2, m4)
    c22 = add(sub(add(m1, m3), m2), m6)  # m1 - m2 + m3 + m6

    return join(c11, c12, c21, c22)
179 | 184 |
|
180 | 185 |
|
if __name__ == "__main__":
    # Small demo: multiply two 3x3 matrices. threshold=1 forces the Strassen
    # recursion (and the padding to 4x4) to be exercised even at this size.
    left = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    right = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]

    product = strassen(left, right, threshold=1)
    print("A * B =")
    for row in product:
        print(row)

    # Cross-check against the naive implementation. An explicit raise is used
    # rather than `assert` so the check still runs under `python -O`.
    expected = naive_mul(left, right)
    if expected != product:
        raise AssertionError("Strassen result differs from naive multiplication!")
    print("Verified: result matches naive multiplication.")