Skip to content

Commit 8a070b5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 656b092 commit 8a070b5

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ def pad_matrix_to_size(matrix: Matrix, target_size: int) -> Matrix:
6767
return padded_matrix
6868

6969

70-
def remove_matrix_padding(matrix: Matrix, original_rows: int, original_cols: int) -> Matrix:
70+
def remove_matrix_padding(
71+
matrix: Matrix, original_rows: int, original_cols: int
72+
) -> Matrix:
7173
"""
7274
Remove zero padding from a matrix to restore its original size.
7375
"""
7476
return [row[:original_cols] for row in matrix[:original_rows]]
7577

7678

77-
def split_matrix_into_quadrants(matrix: Matrix) -> tuple[Matrix, Matrix, Matrix, Matrix]:
79+
def split_matrix_into_quadrants(
80+
matrix: Matrix,
81+
) -> tuple[Matrix, Matrix, Matrix, Matrix]:
7882
"""
7983
Split a matrix into four equal quadrants:
8084
top-left, top-right, bottom-left, bottom-right.
@@ -107,7 +111,9 @@ def join_matrix_quadrants(
107111
return combined_matrix
108112

109113

110-
def strassen_matrix_multiplication(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
114+
def strassen_matrix_multiplication(
115+
matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64
116+
) -> Matrix:
111117
"""
112118
Multiply two square matrices using Strassen's algorithm.
113119
Uses naive multiplication for matrices smaller than the threshold.
@@ -121,16 +127,17 @@ def strassen_matrix_multiplication(matrix_a: Matrix, matrix_b: Matrix, threshold
121127
return []
122128

123129
# Pad matrices to next power of two for even splitting
124-
padded_size = get_next_power_of_two(original_size)
125-
if padded_size != original_size:
130+
if (padded_size := get_next_power_of_two(original_size)) != original_size:
126131
matrix_a = pad_matrix_to_size(matrix_a, padded_size)
127132
matrix_b = pad_matrix_to_size(matrix_b, padded_size)
128133

129134
result_padded = _strassen_recursive_multiply(matrix_a, matrix_b, threshold)
130135
return remove_matrix_padding(result_padded, original_size, original_size)
131136

132137

133-
def _strassen_recursive_multiply(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
138+
def _strassen_recursive_multiply(
139+
matrix_a: Matrix, matrix_b: Matrix, threshold: int
140+
) -> Matrix:
134141
"""
135142
Recursive implementation of Strassen's algorithm.
136143
"""
@@ -148,13 +155,19 @@ def _strassen_recursive_multiply(matrix_a: Matrix, matrix_b: Matrix, threshold:
148155
b11, b12, b21, b22 = split_matrix_into_quadrants(matrix_b)
149156

150157
# Compute the 7 Strassen products
151-
p1 = _strassen_recursive_multiply(add_matrices(a11, a22), add_matrices(b11, b22), threshold)
158+
p1 = _strassen_recursive_multiply(
159+
add_matrices(a11, a22), add_matrices(b11, b22), threshold
160+
)
152161
p2 = _strassen_recursive_multiply(add_matrices(a21, a22), b11, threshold)
153162
p3 = _strassen_recursive_multiply(a11, subtract_matrices(b12, b22), threshold)
154163
p4 = _strassen_recursive_multiply(a22, subtract_matrices(b21, b11), threshold)
155164
p5 = _strassen_recursive_multiply(add_matrices(a11, a12), b22, threshold)
156-
p6 = _strassen_recursive_multiply(subtract_matrices(a21, a11), add_matrices(b11, b12), threshold)
157-
p7 = _strassen_recursive_multiply(subtract_matrices(a12, a22), add_matrices(b21, b22), threshold)
165+
p6 = _strassen_recursive_multiply(
166+
subtract_matrices(a21, a11), add_matrices(b11, b12), threshold
167+
)
168+
p7 = _strassen_recursive_multiply(
169+
subtract_matrices(a12, a22), add_matrices(b21, b22), threshold
170+
)
158171

159172
# Combine partial results into final quadrants
160173
c11 = add_matrices(subtract_matrices(add_matrices(p1, p4), p5), p7)
@@ -175,5 +188,7 @@ def _strassen_recursive_multiply(matrix_a: Matrix, matrix_b: Matrix, threshold:
175188
print(row)
176189

177190
expected_matrix = multiply_matrices_naive(matrix_A, matrix_B)
178-
assert expected_matrix == result_matrix, "Strassen result differs from naive multiplication!"
191+
assert expected_matrix == result_matrix, (
192+
"Strassen result differs from naive multiplication!"
193+
)
179194
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)