Skip to content

Commit dd61898

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1dd58fe commit dd61898

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed

linear_algebra/matrix_trace.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def trace(matrix: NDArray[float64]) -> float:
1616
"""
17-
Calculate the trace of a square matrix.
17+
Calculate the trace of a square matrix.
1818
The trace is the sum of the diagonal elements of a square matrix.
1919
Parameters:
2020
matrix (NDArray[float64]): A square matrix
@@ -40,96 +40,96 @@ def trace(matrix: NDArray[float64]) -> float:
4040
"""
4141
if matrix.shape[0] != matrix.shape[1]:
4242
raise ValueError("Matrix must be square")
43-
43+
4444
return float(np.sum(np.diag(matrix)))
4545

4646

4747
def trace_properties_demo(matrix: NDArray[float64]) -> dict:
4848
"""
4949
Demonstrate various properties of the trace operation.
5050
Parameters:
51-
matrix (NDArray[float64]): A square matrix
51+
matrix (NDArray[float64]): A square matrix
5252
Returns:
5353
dict: Dictionary containing trace properties and calculations
5454
"""
5555
if matrix.shape[0] != matrix.shape[1]:
5656
raise ValueError("Matrix must be square")
57-
57+
5858
n = matrix.shape[0]
59-
59+
6060
# Calculate trace
6161
tr = trace(matrix)
62-
62+
6363
# Calculate transpose trace (should be equal to original)
6464
tr_transpose = trace(matrix.T)
65-
65+
6666
# Calculate trace of scalar multiple
6767
scalar = 2.0
6868
tr_scalar = trace(scalar * matrix)
69-
69+
7070
# Create identity matrix for comparison
7171
identity = np.eye(n, dtype=float64)
7272
tr_identity = trace(identity)
73-
73+
7474
return {
7575
"original_trace": tr,
7676
"transpose_trace": tr_transpose,
7777
"scalar_multiple_trace": tr_scalar,
7878
"scalar_factor": scalar,
7979
"identity_trace": tr_identity,
8080
"trace_equals_transpose": abs(tr - tr_transpose) < 1e-10,
81-
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10
81+
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10,
8282
}
8383

8484

8585
def test_trace() -> None:
8686
"""
87-
Test function for matrix trace calculation.
87+
Test function for matrix trace calculation.
8888
>>> test_trace() # self running tests
8989
"""
9090
# Test 1: 2x2 matrix
9191
matrix_2x2 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
9292
tr_2x2 = trace(matrix_2x2)
9393
assert abs(tr_2x2 - 5.0) < 1e-10, "2x2 trace calculation failed"
94-
94+
9595
# Test 2: 3x3 matrix
96-
matrix_3x3 = np.array([[2.0, -1.0, 3.0],
97-
[4.0, 5.0, -2.0],
98-
[1.0, 0.0, 7.0]], dtype=float)
96+
matrix_3x3 = np.array(
97+
[[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
98+
)
9999
tr_3x3 = trace(matrix_3x3)
100100
assert abs(tr_3x3 - 14.0) < 1e-10, "3x3 trace calculation failed"
101-
101+
102102
# Test 3: Identity matrix
103103
identity_4x4 = np.eye(4, dtype=float)
104104
tr_identity = trace(identity_4x4)
105-
assert (
106-
abs(tr_identity - 4.0) < 1e-10
107-
), "Identity matrix trace should equal dimension"
108-
105+
assert abs(tr_identity - 4.0) < 1e-10, (
106+
"Identity matrix trace should equal dimension"
107+
)
108+
109109
# Test 4: Zero matrix
110110
zero_matrix = np.zeros((3, 3), dtype=float)
111111
tr_zero = trace(zero_matrix)
112112
assert abs(tr_zero) < 1e-10, "Zero matrix should have zero trace"
113-
113+
114114
# Test 5: Trace properties
115-
test_matrix = np.array([[1.0, 2.0, 3.0],
116-
[4.0, 5.0, 6.0],
117-
[7.0, 8.0, 9.0]], dtype=float)
115+
test_matrix = np.array(
116+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=float
117+
)
118118
properties = trace_properties_demo(test_matrix)
119119
assert properties["trace_equals_transpose"], "Trace should equal transpose trace"
120120
assert properties["scalar_property_check"], "Scalar multiplication property failed"
121-
121+
122122
# Test 6: Diagonal matrix
123123
diagonal_matrix = np.diag([1.0, 2.0, 3.0, 4.0])
124124
tr_diagonal = trace(diagonal_matrix)
125125
expected = 1.0 + 2.0 + 3.0 + 4.0
126-
assert (
127-
abs(tr_diagonal - expected) < 1e-10
128-
), "Diagonal matrix trace should equal sum of diagonal elements"
126+
assert abs(tr_diagonal - expected) < 1e-10, (
127+
"Diagonal matrix trace should equal sum of diagonal elements"
128+
)
129129

130130

131131
if __name__ == "__main__":
132132
import doctest
133-
133+
134134
doctest.testmod()
135135
test_trace()

0 commit comments

Comments
 (0)