Skip to content

Commit 2f6508a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6205411 commit 2f6508a

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

linear_algebra/matrix_trace.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
def trace(matrix: NDArray[float64]) -> float:
1616
"""
1717
Calculate the trace of a square matrix.
18-
18+
1919
The trace is the sum of the diagonal elements of a square matrix.
2020
2121
Parameters:
@@ -43,95 +43,99 @@ def trace(matrix: NDArray[float64]) -> float:
4343
"""
4444
if matrix.shape[0] != matrix.shape[1]:
4545
raise ValueError("Matrix must be square")
46-
46+
4747
return float(np.sum(np.diag(matrix)))
4848

4949

5050
def trace_properties_demo(matrix: NDArray[float64]) -> dict:
5151
"""
5252
Demonstrate various properties of the trace operation.
53-
53+
5454
Parameters:
5555
matrix (NDArray[float64]): A square matrix
56-
56+
5757
Returns:
5858
dict: Dictionary containing trace properties and calculations
5959
"""
6060
if matrix.shape[0] != matrix.shape[1]:
6161
raise ValueError("Matrix must be square")
62-
62+
6363
n = matrix.shape[0]
64-
64+
6565
# Calculate trace
6666
tr = trace(matrix)
67-
67+
6868
# Calculate transpose trace (should be equal to original)
6969
tr_transpose = trace(matrix.T)
70-
70+
7171
# Calculate trace of scalar multiple
7272
scalar = 2.0
7373
tr_scalar = trace(scalar * matrix)
74-
74+
7575
# Create identity matrix for comparison
7676
identity = np.eye(n, dtype=float64)
7777
tr_identity = trace(identity)
78-
78+
7979
return {
8080
"original_trace": tr,
8181
"transpose_trace": tr_transpose,
8282
"scalar_multiple_trace": tr_scalar,
8383
"scalar_factor": scalar,
8484
"identity_trace": tr_identity,
8585
"trace_equals_transpose": abs(tr - tr_transpose) < 1e-10,
86-
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10
86+
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10,
8787
}
8888

8989

9090
def test_trace() -> None:
9191
"""
9292
Test function for matrix trace calculation.
93-
93+
9494
>>> test_trace() # self running tests
9595
"""
9696
# Test 1: 2x2 matrix
9797
matrix_2x2 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
9898
tr_2x2 = trace(matrix_2x2)
9999
assert abs(tr_2x2 - 5.0) < 1e-10, "2x2 trace calculation failed"
100-
100+
101101
# Test 2: 3x3 matrix
102-
matrix_3x3 = np.array([[2.0, -1.0, 3.0],
103-
[4.0, 5.0, -2.0],
104-
[1.0, 0.0, 7.0]], dtype=float)
102+
matrix_3x3 = np.array(
103+
[[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
104+
)
105105
tr_3x3 = trace(matrix_3x3)
106106
assert abs(tr_3x3 - 14.0) < 1e-10, "3x3 trace calculation failed"
107-
107+
108108
# Test 3: Identity matrix
109109
identity_4x4 = np.eye(4, dtype=float)
110110
tr_identity = trace(identity_4x4)
111-
assert abs(tr_identity - 4.0) < 1e-10, "Identity matrix trace should equal dimension"
112-
111+
assert abs(tr_identity - 4.0) < 1e-10, (
112+
"Identity matrix trace should equal dimension"
113+
)
114+
113115
# Test 4: Zero matrix
114116
zero_matrix = np.zeros((3, 3), dtype=float)
115117
tr_zero = trace(zero_matrix)
116118
assert abs(tr_zero) < 1e-10, "Zero matrix should have zero trace"
117-
119+
118120
# Test 5: Trace properties
119-
test_matrix = np.array([[1.0, 2.0, 3.0],
120-
[4.0, 5.0, 6.0],
121-
[7.0, 8.0, 9.0]], dtype=float)
121+
test_matrix = np.array(
122+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=float
123+
)
122124
properties = trace_properties_demo(test_matrix)
123125
assert properties["trace_equals_transpose"], "Trace should equal transpose trace"
124126
assert properties["scalar_property_check"], "Scalar multiplication property failed"
125-
127+
126128
# Test 6: Diagonal matrix
127129
diagonal_matrix = np.diag([1.0, 2.0, 3.0, 4.0])
128130
tr_diagonal = trace(diagonal_matrix)
129131
expected = 1.0 + 2.0 + 3.0 + 4.0
130-
assert abs(tr_diagonal - expected) < 1e-10, "Diagonal matrix trace should equal sum of diagonal elements"
132+
assert abs(tr_diagonal - expected) < 1e-10, (
133+
"Diagonal matrix trace should equal sum of diagonal elements"
134+
)
131135

132136

133137
if __name__ == "__main__":
134138
import doctest
135-
139+
136140
doctest.testmod()
137141
test_trace()

0 commit comments

Comments
 (0)