Skip to content

Commit 1dd58fe

Browse files
line length error resolve
1 parent 4f1214e commit 1dd58fe

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

linear_algebra/matrix_trace.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def trace(matrix: NDArray[float64]) -> float:
1616
"""
17-
Calculate the trace of a square matrix.
17+
Calculate the trace of a square matrix.
1818
The trace is the sum of the diagonal elements of a square matrix.
1919
Parameters:
2020
matrix (NDArray[float64]): A square matrix
@@ -27,106 +27,109 @@ def trace(matrix: NDArray[float64]) -> float:
2727
>>> matrix = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
2828
>>> trace(matrix)
2929
5.0
30-
>>> matrix = np.array([[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float)
30+
31+
>>> matrix = np.array(
32+
... [[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
33+
... )
3134
>>> trace(matrix)
3235
14.0
36+
3337
>>> matrix = np.array([[5.0]], dtype=float)
3438
>>> trace(matrix)
3539
5.0
3640
"""
3741
if matrix.shape[0] != matrix.shape[1]:
3842
raise ValueError("Matrix must be square")
39-
43+
4044
return float(np.sum(np.diag(matrix)))
4145

4246

4347
def trace_properties_demo(matrix: NDArray[float64]) -> dict:
4448
"""
4549
Demonstrate various properties of the trace operation.
4650
Parameters:
47-
matrix (NDArray[float64]): A square matrix
51+
matrix (NDArray[float64]): A square matrix
4852
Returns:
4953
dict: Dictionary containing trace properties and calculations
5054
"""
5155
if matrix.shape[0] != matrix.shape[1]:
5256
raise ValueError("Matrix must be square")
53-
57+
5458
n = matrix.shape[0]
55-
59+
5660
# Calculate trace
5761
tr = trace(matrix)
58-
62+
5963
# Calculate transpose trace (should be equal to original)
6064
tr_transpose = trace(matrix.T)
61-
65+
6266
# Calculate trace of scalar multiple
6367
scalar = 2.0
6468
tr_scalar = trace(scalar * matrix)
65-
69+
6670
# Create identity matrix for comparison
6771
identity = np.eye(n, dtype=float64)
6872
tr_identity = trace(identity)
69-
73+
7074
return {
7175
"original_trace": tr,
7276
"transpose_trace": tr_transpose,
7377
"scalar_multiple_trace": tr_scalar,
7478
"scalar_factor": scalar,
7579
"identity_trace": tr_identity,
7680
"trace_equals_transpose": abs(tr - tr_transpose) < 1e-10,
77-
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10,
81+
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10
7882
}
7983

8084

8185
def test_trace() -> None:
8286
"""
83-
Test function for matrix trace calculation.
84-
87+
Test function for matrix trace calculation.
8588
>>> test_trace() # self running tests
8689
"""
8790
# Test 1: 2x2 matrix
8891
matrix_2x2 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
8992
tr_2x2 = trace(matrix_2x2)
9093
assert abs(tr_2x2 - 5.0) < 1e-10, "2x2 trace calculation failed"
91-
94+
9295
# Test 2: 3x3 matrix
93-
matrix_3x3 = np.array(
94-
[[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
95-
)
96+
matrix_3x3 = np.array([[2.0, -1.0, 3.0],
97+
[4.0, 5.0, -2.0],
98+
[1.0, 0.0, 7.0]], dtype=float)
9699
tr_3x3 = trace(matrix_3x3)
97100
assert abs(tr_3x3 - 14.0) < 1e-10, "3x3 trace calculation failed"
98-
101+
99102
# Test 3: Identity matrix
100103
identity_4x4 = np.eye(4, dtype=float)
101104
tr_identity = trace(identity_4x4)
102-
assert abs(tr_identity - 4.0) < 1e-10, (
103-
"Identity matrix trace should equal dimension"
104-
)
105-
105+
assert (
106+
abs(tr_identity - 4.0) < 1e-10
107+
), "Identity matrix trace should equal dimension"
108+
106109
# Test 4: Zero matrix
107110
zero_matrix = np.zeros((3, 3), dtype=float)
108111
tr_zero = trace(zero_matrix)
109112
assert abs(tr_zero) < 1e-10, "Zero matrix should have zero trace"
110-
113+
111114
# Test 5: Trace properties
112-
test_matrix = np.array(
113-
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=float
114-
)
115+
test_matrix = np.array([[1.0, 2.0, 3.0],
116+
[4.0, 5.0, 6.0],
117+
[7.0, 8.0, 9.0]], dtype=float)
115118
properties = trace_properties_demo(test_matrix)
116119
assert properties["trace_equals_transpose"], "Trace should equal transpose trace"
117120
assert properties["scalar_property_check"], "Scalar multiplication property failed"
118-
121+
119122
# Test 6: Diagonal matrix
120123
diagonal_matrix = np.diag([1.0, 2.0, 3.0, 4.0])
121124
tr_diagonal = trace(diagonal_matrix)
122125
expected = 1.0 + 2.0 + 3.0 + 4.0
123-
assert abs(tr_diagonal - expected) < 1e-10, (
124-
"Diagonal matrix trace should equal sum of diagonal elements"
125-
)
126+
assert (
127+
abs(tr_diagonal - expected) < 1e-10
128+
), "Diagonal matrix trace should equal sum of diagonal elements"
126129

127130

128131
if __name__ == "__main__":
129132
import doctest
130-
133+
131134
doctest.testmod()
132135
test_trace()

0 commit comments

Comments
 (0)