Skip to content

Commit 5e951b6

Browse files
committed
feat: Strassen's matrix multiplication algorithm added
1 parent 8a070b5 commit 5e951b6

File tree

1 file changed

+115
-112
lines changed

1 file changed

+115
-112
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 115 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
Strassen's Matrix Multiplication Algorithm (Descriptive Version)
3-
---------------------------------------------------------------
2+
Strassen's Matrix Multiplication Algorithm
3+
------------------------------------------
44
An optimized divide-and-conquer algorithm for matrix multiplication that
55
reduces the number of multiplications from 8 (in the naive approach)
66
to 7 per recursion step.
@@ -15,180 +15,183 @@
1515
Matrix = list[list[int]]
1616

1717

18-
def add_matrices(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
18+
def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
1919
"""
2020
Add two square matrices of the same size.
21+
22+
>>> add([[1,2],[3,4]], [[5,6],[7,8]])
23+
[[6, 8], [10, 12]]
2124
"""
22-
size = len(matrix_a)
23-
return [[matrix_a[i][j] + matrix_b[i][j] for j in range(size)] for i in range(size)]
25+
n = len(matrix_a)
26+
return [[matrix_a[i][j] + matrix_b[i][j] for j in range(n)] for i in range(n)]
2427

2528

26-
def subtract_matrices(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
29+
def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
2730
"""
2831
Subtract matrix_b from matrix_a.
32+
33+
>>> sub([[5,6],[7,8]], [[1,2],[3,4]])
34+
[[4, 4], [4, 4]]
2935
"""
30-
size = len(matrix_a)
31-
return [[matrix_a[i][j] - matrix_b[i][j] for j in range(size)] for i in range(size)]
36+
n = len(matrix_a)
37+
return [[matrix_a[i][j] - matrix_b[i][j] for j in range(n)] for i in range(n)]
3238

3339

34-
def multiply_matrices_naive(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
40+
def naive_mul(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
3541
"""
3642
Multiply two square matrices using the naive O(n^3) method.
37-
"""
38-
size = len(matrix_a)
39-
result_matrix = [[0] * size for _ in range(size)]
4043
41-
for i in range(size):
42-
for k in range(size):
43-
for j in range(size):
44-
result_matrix[i][j] += matrix_a[i][k] * matrix_b[k][j]
45-
return result_matrix
44+
>>> naive_mul([[1,2],[3,4]], [[5,6],[7,8]])
45+
[[19, 22], [43, 50]]
46+
"""
47+
n = len(matrix_a)
48+
result = [[0] * n for _ in range(n)]
49+
for i in range(n):
50+
row_a = matrix_a[i]
51+
row_result = result[i]
52+
for k in range(n):
53+
a_ik = row_a[k]
54+
col_b = matrix_b[k]
55+
for j in range(n):
56+
row_result[j] += a_ik * col_b[j]
57+
return result
4658

4759

48-
def get_next_power_of_two(n: int) -> int:
60+
def next_power_of_two(n: int) -> int:
4961
"""
5062
Return the next power of two greater than or equal to n.
63+
64+
>>> next_power_of_two(5)
65+
8
5166
"""
5267
power = 1
5368
while power < n:
5469
power <<= 1
5570
return power
5671

5772

58-
def pad_matrix_to_size(matrix: Matrix, target_size: int) -> Matrix:
73+
def pad_matrix(matrix: Matrix, size: int) -> Matrix:
5974
"""
60-
Pad a matrix with zeros to reach the given target size.
75+
Pad a matrix with zeros to reach the given size.
76+
77+
>>> pad_matrix([[1,2],[3,4]], 4)
78+
[[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
6179
"""
62-
rows, cols = len(matrix), len(matrix[0])
63-
padded_matrix = [[0] * target_size for _ in range(target_size)]
80+
rows = len(matrix)
81+
cols = len(matrix[0])
82+
padded = [[0] * size for _ in range(size)]
6483
for i in range(rows):
6584
for j in range(cols):
66-
padded_matrix[i][j] = matrix[i][j]
67-
return padded_matrix
85+
padded[i][j] = matrix[i][j]
86+
return padded
6887

6988

70-
def remove_matrix_padding(
71-
matrix: Matrix, original_rows: int, original_cols: int
72-
) -> Matrix:
89+
def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix:
7390
"""
74-
Remove zero padding from a matrix to restore its original size.
91+
Remove padding from a matrix.
92+
93+
>>> unpad_matrix([[1,2,0],[3,4,0],[0,0,0]], 2, 2)
94+
[[1, 2], [3, 4]]
7595
"""
76-
return [row[:original_cols] for row in matrix[:original_rows]]
96+
return [row[:cols] for row in matrix[:rows]]
7797

7898

79-
def split_matrix_into_quadrants(
80-
matrix: Matrix,
81-
) -> tuple[Matrix, Matrix, Matrix, Matrix]:
99+
def split(matrix: Matrix) -> tuple:
82100
"""
83-
Split a matrix into four equal quadrants:
84-
top-left, top-right, bottom-left, bottom-right.
101+
Split a matrix into four quadrants (top-left, top-right, bottom-left, bottom-right).
102+
103+
>>> split([[1,2],[3,4]])
104+
([[1]], [[2]], [[3]], [[4]])
85105
"""
86-
size = len(matrix)
87-
mid = size // 2
106+
n = len(matrix)
107+
mid = n // 2
88108
top_left = [[matrix[i][j] for j in range(mid)] for i in range(mid)]
89-
top_right = [[matrix[i][j] for j in range(mid, size)] for i in range(mid)]
90-
bottom_left = [[matrix[i][j] for j in range(mid)] for i in range(mid, size)]
91-
bottom_right = [[matrix[i][j] for j in range(mid, size)] for i in range(mid, size)]
109+
top_right = [[matrix[i][j] for j in range(mid, n)] for i in range(mid)]
110+
bottom_left = [[matrix[i][j] for j in range(mid)] for i in range(mid, n)]
111+
bottom_right = [[matrix[i][j] for j in range(mid, n)] for i in range(mid, n)]
92112
return top_left, top_right, bottom_left, bottom_right
93113

94114

95-
def join_matrix_quadrants(
96-
top_left: Matrix, top_right: Matrix, bottom_left: Matrix, bottom_right: Matrix
97-
) -> Matrix:
115+
def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
98116
"""
99-
Join four quadrants into a single square matrix.
100-
"""
101-
quadrant_size = len(top_left)
102-
full_size = quadrant_size * 2
103-
combined_matrix = [[0] * full_size for _ in range(full_size)]
117+
Join four quadrants into a single matrix.
104118
105-
for i in range(quadrant_size):
106-
for j in range(quadrant_size):
107-
combined_matrix[i][j] = top_left[i][j]
108-
combined_matrix[i][j + quadrant_size] = top_right[i][j]
109-
combined_matrix[i + quadrant_size][j] = bottom_left[i][j]
110-
combined_matrix[i + quadrant_size][j + quadrant_size] = bottom_right[i][j]
111-
return combined_matrix
119+
>>> join([[1]], [[2]], [[3]], [[4]])
120+
[[1, 2], [3, 4]]
121+
"""
122+
n2 = len(c11)
123+
n = n2 * 2
124+
result = [[0] * n for _ in range(n)]
125+
for i in range(n2):
126+
for j in range(n2):
127+
result[i][j] = c11[i][j]
128+
result[i][j + n2] = c12[i][j]
129+
result[i + n2][j] = c21[i][j]
130+
result[i + n2][j + n2] = c22[i][j]
131+
return result
112132

113133

114-
def strassen_matrix_multiplication(
115-
matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64
116-
) -> Matrix:
134+
def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
117135
"""
118136
Multiply two square matrices using Strassen's algorithm.
119-
Uses naive multiplication for matrices smaller than the threshold.
137+
Uses naive multiplication for matrices smaller than threshold.
138+
139+
>>> strassen([[1,2],[3,4]], [[5,6],[7,8]])
140+
[[19, 22], [43, 50]]
120141
"""
121142
assert len(matrix_a) == len(matrix_a[0]) == len(matrix_b) == len(matrix_b[0]), (
122-
"Strassen's algorithm supports only square matrices."
143+
"Only square matrices supported"
123144
)
124145

125-
original_size = len(matrix_a)
126-
if original_size == 0:
146+
n_orig = len(matrix_a)
147+
if n_orig == 0:
127148
return []
128149

129-
# Pad matrices to next power of two for even splitting
130-
if (padded_size := get_next_power_of_two(original_size)) != original_size:
131-
matrix_a = pad_matrix_to_size(matrix_a, padded_size)
132-
matrix_b = pad_matrix_to_size(matrix_b, padded_size)
133-
134-
result_padded = _strassen_recursive_multiply(matrix_a, matrix_b, threshold)
135-
return remove_matrix_padding(result_padded, original_size, original_size)
136-
150+
if (m := next_power_of_two(n_orig)) != n_orig:
151+
a_pad = pad_matrix(matrix_a, m)
152+
b_pad = pad_matrix(matrix_b, m)
153+
else:
154+
a_pad, b_pad = matrix_a, matrix_b
137155

138-
def _strassen_recursive_multiply(
139-
matrix_a: Matrix, matrix_b: Matrix, threshold: int
140-
) -> Matrix:
141-
"""
142-
Recursive implementation of Strassen's algorithm.
143-
"""
144-
size = len(matrix_a)
156+
c_pad = _strassen_recursive(a_pad, b_pad, threshold)
157+
return unpad_matrix(c_pad, n_orig, n_orig)
145158

146-
# Base case: use naive multiplication for small matrices
147-
if size <= threshold:
148-
return multiply_matrices_naive(matrix_a, matrix_b)
149159

150-
if size == 1:
160+
def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
161+
n = len(matrix_a)
162+
if n <= threshold:
163+
return naive_mul(matrix_a, matrix_b)
164+
if n == 1:
151165
return [[matrix_a[0][0] * matrix_b[0][0]]]
152166

153-
# Split matrices into quadrants
154-
a11, a12, a21, a22 = split_matrix_into_quadrants(matrix_a)
155-
b11, b12, b21, b22 = split_matrix_into_quadrants(matrix_b)
167+
a11, a12, a21, a22 = split(matrix_a)
168+
b11, b12, b21, b22 = split(matrix_b)
156169

157-
# Compute the 7 Strassen products
158-
p1 = _strassen_recursive_multiply(
159-
add_matrices(a11, a22), add_matrices(b11, b22), threshold
160-
)
161-
p2 = _strassen_recursive_multiply(add_matrices(a21, a22), b11, threshold)
162-
p3 = _strassen_recursive_multiply(a11, subtract_matrices(b12, b22), threshold)
163-
p4 = _strassen_recursive_multiply(a22, subtract_matrices(b21, b11), threshold)
164-
p5 = _strassen_recursive_multiply(add_matrices(a11, a12), b22, threshold)
165-
p6 = _strassen_recursive_multiply(
166-
subtract_matrices(a21, a11), add_matrices(b11, b12), threshold
167-
)
168-
p7 = _strassen_recursive_multiply(
169-
subtract_matrices(a12, a22), add_matrices(b21, b22), threshold
170-
)
170+
m1 = _strassen_recursive(add(a11, a22), add(b11, b22), threshold)
171+
m2 = _strassen_recursive(add(a21, a22), b11, threshold)
172+
m3 = _strassen_recursive(a11, sub(b12, b22), threshold)
173+
m4 = _strassen_recursive(a22, sub(b21, b11), threshold)
174+
m5 = _strassen_recursive(add(a11, a12), b22, threshold)
175+
m6 = _strassen_recursive(sub(a21, a11), add(b11, b12), threshold)
176+
m7 = _strassen_recursive(sub(a12, a22), add(b21, b22), threshold)
171177

172-
# Combine partial results into final quadrants
173-
c11 = add_matrices(subtract_matrices(add_matrices(p1, p4), p5), p7)
174-
c12 = add_matrices(p3, p5)
175-
c21 = add_matrices(p2, p4)
176-
c22 = add_matrices(subtract_matrices(add_matrices(p1, p3), p2), p6)
178+
c11 = add(sub(add(m1, m4), m5), m7)
179+
c12 = add(m3, m5)
180+
c21 = add(m2, m4)
181+
c22 = add(sub(add(m1, m3), m2), m6)
177182

178-
return join_matrix_quadrants(c11, c12, c21, c22)
183+
return join(c11, c12, c21, c22)
179184

180185

181186
if __name__ == "__main__":
182-
matrix_A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
183-
matrix_B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
187+
A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
188+
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
184189

185-
result_matrix = strassen_matrix_multiplication(matrix_A, matrix_B, threshold=1)
186-
print("A × B =")
187-
for row in result_matrix:
190+
C = strassen(A, B, threshold=1)
191+
print("A * B =")
192+
for row in C:
188193
print(row)
189194

190-
expected_matrix = multiply_matrices_naive(matrix_A, matrix_B)
191-
assert expected_matrix == result_matrix, (
192-
"Strassen result differs from naive multiplication!"
193-
)
195+
expected = naive_mul(A, B)
196+
assert expected == C, "Strassen result differs from naive multiplication!"
194197
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)