@@ -67,14 +67,18 @@ def pad_matrix_to_size(matrix: Matrix, target_size: int) -> Matrix:
6767 return padded_matrix
6868
6969
70- def remove_matrix_padding (matrix : Matrix , original_rows : int , original_cols : int ) -> Matrix :
70+ def remove_matrix_padding (
71+ matrix : Matrix , original_rows : int , original_cols : int
72+ ) -> Matrix :
7173 """
7274 Remove zero padding from a matrix to restore its original size.
7375 """
7476 return [row [:original_cols ] for row in matrix [:original_rows ]]
7577
7678
77- def split_matrix_into_quadrants (matrix : Matrix ) -> tuple [Matrix , Matrix , Matrix , Matrix ]:
79+ def split_matrix_into_quadrants (
80+ matrix : Matrix ,
81+ ) -> tuple [Matrix , Matrix , Matrix , Matrix ]:
7882 """
7983 Split a matrix into four equal quadrants:
8084 top-left, top-right, bottom-left, bottom-right.
@@ -107,7 +111,9 @@ def join_matrix_quadrants(
107111 return combined_matrix
108112
109113
110- def strassen_matrix_multiplication (matrix_a : Matrix , matrix_b : Matrix , threshold : int = 64 ) -> Matrix :
114+ def strassen_matrix_multiplication (
115+ matrix_a : Matrix , matrix_b : Matrix , threshold : int = 64
116+ ) -> Matrix :
111117 """
112118 Multiply two square matrices using Strassen's algorithm.
113119 Uses naive multiplication for matrices smaller than the threshold.
@@ -121,16 +127,17 @@ def strassen_matrix_multiplication(matrix_a: Matrix, matrix_b: Matrix, threshold
121127 return []
122128
123129 # Pad matrices to next power of two for even splitting
124- padded_size = get_next_power_of_two (original_size )
125- if padded_size != original_size :
130+ if (padded_size := get_next_power_of_two (original_size )) != original_size :
126131 matrix_a = pad_matrix_to_size (matrix_a , padded_size )
127132 matrix_b = pad_matrix_to_size (matrix_b , padded_size )
128133
129134 result_padded = _strassen_recursive_multiply (matrix_a , matrix_b , threshold )
130135 return remove_matrix_padding (result_padded , original_size , original_size )
131136
132137
133- def _strassen_recursive_multiply (matrix_a : Matrix , matrix_b : Matrix , threshold : int ) -> Matrix :
138+ def _strassen_recursive_multiply (
139+ matrix_a : Matrix , matrix_b : Matrix , threshold : int
140+ ) -> Matrix :
134141 """
135142 Recursive implementation of Strassen's algorithm.
136143 """
@@ -148,13 +155,19 @@ def _strassen_recursive_multiply(matrix_a: Matrix, matrix_b: Matrix, threshold:
148155 b11 , b12 , b21 , b22 = split_matrix_into_quadrants (matrix_b )
149156
150157 # Compute the 7 Strassen products
151- p1 = _strassen_recursive_multiply (add_matrices (a11 , a22 ), add_matrices (b11 , b22 ), threshold )
158+ p1 = _strassen_recursive_multiply (
159+ add_matrices (a11 , a22 ), add_matrices (b11 , b22 ), threshold
160+ )
152161 p2 = _strassen_recursive_multiply (add_matrices (a21 , a22 ), b11 , threshold )
153162 p3 = _strassen_recursive_multiply (a11 , subtract_matrices (b12 , b22 ), threshold )
154163 p4 = _strassen_recursive_multiply (a22 , subtract_matrices (b21 , b11 ), threshold )
155164 p5 = _strassen_recursive_multiply (add_matrices (a11 , a12 ), b22 , threshold )
156- p6 = _strassen_recursive_multiply (subtract_matrices (a21 , a11 ), add_matrices (b11 , b12 ), threshold )
157- p7 = _strassen_recursive_multiply (subtract_matrices (a12 , a22 ), add_matrices (b21 , b22 ), threshold )
165+ p6 = _strassen_recursive_multiply (
166+ subtract_matrices (a21 , a11 ), add_matrices (b11 , b12 ), threshold
167+ )
168+ p7 = _strassen_recursive_multiply (
169+ subtract_matrices (a12 , a22 ), add_matrices (b21 , b22 ), threshold
170+ )
158171
159172 # Combine partial results into final quadrants
160173 c11 = add_matrices (subtract_matrices (add_matrices (p1 , p4 ), p5 ), p7 )
@@ -175,5 +188,7 @@ def _strassen_recursive_multiply(matrix_a: Matrix, matrix_b: Matrix, threshold:
175188 print (row )
176189
177190 expected_matrix = multiply_matrices_naive (matrix_A , matrix_B )
178- assert expected_matrix == result_matrix , "Strassen result differs from naive multiplication!"
191+ assert expected_matrix == result_matrix , (
192+ "Strassen result differs from naive multiplication!"
193+ )
179194 print ("Verified: result matches naive multiplication." )
0 commit comments