Skip to content

Commit c291794

Browse files
committed
feat: Strassen's matrix multiplication algorithm added
1 parent a337a94 commit c291794

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
https://en.wikipedia.org/wiki/Strassen_algorithm
1313
"""
1414

15-
Matrix = list[list[int]]
15+
matrix = list[list[int]]
1616

1717

18-
def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
18+
def add(matrix_a: matrix, matrix_b: matrix) -> matrix:
1919
"""
2020
Add two square matrices of the same size.
2121
@@ -26,7 +26,7 @@ def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
2626
return [[matrix_a[i][j] + matrix_b[i][j] for j in range(n)] for i in range(n)]
2727

2828

29-
def subtract(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
29+
def subtract(matrix_a: matrix, matrix_b: matrix) -> matrix:
3030
"""
3131
Subtract matrix_b from matrix_a.
3232
@@ -37,7 +37,7 @@ def subtract(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
3737
return [[matrix_a[i][j] - matrix_b[i][j] for j in range(n)] for i in range(n)]
3838

3939

40-
def naive_multiplication(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
40+
def naive_multiplication(matrix_a: matrix, matrix_b: matrix) -> matrix:
4141
"""
4242
Multiply two square matrices using the naive O(n^3) method.
4343
@@ -70,7 +70,7 @@ def next_power_of_two(n: int) -> int:
7070
return power
7171

7272

73-
def pad_matrix(matrix: Matrix, size: int) -> Matrix:
73+
def pad_matrix(matrix: matrix, size: int) -> matrix:
7474
"""
7575
Pad a matrix with zeros to reach the given size.
7676
@@ -86,7 +86,7 @@ def pad_matrix(matrix: Matrix, size: int) -> Matrix:
8686
return padded
8787

8888

89-
def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix:
89+
def unpad_matrix(matrix: matrix, rows: int, cols: int) -> matrix:
9090
"""
9191
Remove padding from a matrix.
9292
@@ -96,7 +96,7 @@ def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix:
9696
return [row[:cols] for row in matrix[:rows]]
9797

9898

99-
def split(matrix: Matrix) -> tuple:
99+
def split(matrix: matrix) -> tuple:
100100
"""
101101
Split a matrix into four quadrants (top-left, top-right, bottom-left, bottom-right).
102102
@@ -112,7 +112,7 @@ def split(matrix: Matrix) -> tuple:
112112
return top_left, top_right, bottom_left, bottom_right
113113

114114

115-
def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
115+
def join(c11: matrix, c12: matrix, c21: matrix, c22: matrix) -> matrix:
116116
"""
117117
Join four quadrants into a single matrix.
118118
@@ -131,7 +131,7 @@ def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
131131
return result
132132

133133

134-
def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
134+
def strassen(matrix_a: matrix, matrix_b: matrix, threshold: int = 64) -> matrix:
135135
"""
136136
Multiply two square matrices using Strassen's algorithm.
137137
Uses naive multiplication for matrices smaller than threshold.
@@ -157,7 +157,7 @@ def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
157157
return unpad_matrix(c_pad, n_orig, n_orig)
158158

159159

160-
def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
160+
def _strassen_recursive(matrix_a: matrix, matrix_b: matrix, threshold: int) -> matrix:
161161
"""
162162
Recursive helper for Strassen's algorithm.
163163

0 commit comments

Comments
 (0)