Skip to content

Commit 6205411

Browse files
Matrix trace calculation
1 parent 249e64e commit 6205411

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

linear_algebra/matrix_trace.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Matrix trace calculation.
3+
4+
The trace of a square matrix is the sum of the elements on the main diagonal.
5+
It's an important linear algebra operation with many applications.
6+
7+
Reference: https://en.wikipedia.org/wiki/Trace_(linear_algebra)
8+
"""
9+
10+
import numpy as np
11+
from numpy import float64
12+
from numpy.typing import NDArray
13+
14+
15+
def trace(matrix: NDArray[float64]) -> float:
16+
"""
17+
Calculate the trace of a square matrix.
18+
19+
The trace is the sum of the diagonal elements of a square matrix.
20+
21+
Parameters:
22+
matrix (NDArray[float64]): A square matrix
23+
24+
Returns:
25+
float: The trace of the matrix
26+
27+
Raises:
28+
ValueError: If the matrix is not square
29+
30+
Examples:
31+
>>> import numpy as np
32+
>>> matrix = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
33+
>>> trace(matrix)
34+
5.0
35+
36+
>>> matrix = np.array([[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float)
37+
>>> trace(matrix)
38+
14.0
39+
40+
>>> matrix = np.array([[5.0]], dtype=float)
41+
>>> trace(matrix)
42+
5.0
43+
"""
44+
if matrix.shape[0] != matrix.shape[1]:
45+
raise ValueError("Matrix must be square")
46+
47+
return float(np.sum(np.diag(matrix)))
48+
49+
50+
def trace_properties_demo(matrix: NDArray[float64]) -> dict:
51+
"""
52+
Demonstrate various properties of the trace operation.
53+
54+
Parameters:
55+
matrix (NDArray[float64]): A square matrix
56+
57+
Returns:
58+
dict: Dictionary containing trace properties and calculations
59+
"""
60+
if matrix.shape[0] != matrix.shape[1]:
61+
raise ValueError("Matrix must be square")
62+
63+
n = matrix.shape[0]
64+
65+
# Calculate trace
66+
tr = trace(matrix)
67+
68+
# Calculate transpose trace (should be equal to original)
69+
tr_transpose = trace(matrix.T)
70+
71+
# Calculate trace of scalar multiple
72+
scalar = 2.0
73+
tr_scalar = trace(scalar * matrix)
74+
75+
# Create identity matrix for comparison
76+
identity = np.eye(n, dtype=float64)
77+
tr_identity = trace(identity)
78+
79+
return {
80+
"original_trace": tr,
81+
"transpose_trace": tr_transpose,
82+
"scalar_multiple_trace": tr_scalar,
83+
"scalar_factor": scalar,
84+
"identity_trace": tr_identity,
85+
"trace_equals_transpose": abs(tr - tr_transpose) < 1e-10,
86+
"scalar_property_check": abs(tr_scalar - scalar * tr) < 1e-10
87+
}
88+
89+
90+
def test_trace() -> None:
91+
"""
92+
Test function for matrix trace calculation.
93+
94+
>>> test_trace() # self running tests
95+
"""
96+
# Test 1: 2x2 matrix
97+
matrix_2x2 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
98+
tr_2x2 = trace(matrix_2x2)
99+
assert abs(tr_2x2 - 5.0) < 1e-10, "2x2 trace calculation failed"
100+
101+
# Test 2: 3x3 matrix
102+
matrix_3x3 = np.array([[2.0, -1.0, 3.0],
103+
[4.0, 5.0, -2.0],
104+
[1.0, 0.0, 7.0]], dtype=float)
105+
tr_3x3 = trace(matrix_3x3)
106+
assert abs(tr_3x3 - 14.0) < 1e-10, "3x3 trace calculation failed"
107+
108+
# Test 3: Identity matrix
109+
identity_4x4 = np.eye(4, dtype=float)
110+
tr_identity = trace(identity_4x4)
111+
assert abs(tr_identity - 4.0) < 1e-10, "Identity matrix trace should equal dimension"
112+
113+
# Test 4: Zero matrix
114+
zero_matrix = np.zeros((3, 3), dtype=float)
115+
tr_zero = trace(zero_matrix)
116+
assert abs(tr_zero) < 1e-10, "Zero matrix should have zero trace"
117+
118+
# Test 5: Trace properties
119+
test_matrix = np.array([[1.0, 2.0, 3.0],
120+
[4.0, 5.0, 6.0],
121+
[7.0, 8.0, 9.0]], dtype=float)
122+
properties = trace_properties_demo(test_matrix)
123+
assert properties["trace_equals_transpose"], "Trace should equal transpose trace"
124+
assert properties["scalar_property_check"], "Scalar multiplication property failed"
125+
126+
# Test 6: Diagonal matrix
127+
diagonal_matrix = np.diag([1.0, 2.0, 3.0, 4.0])
128+
tr_diagonal = trace(diagonal_matrix)
129+
expected = 1.0 + 2.0 + 3.0 + 4.0
130+
assert abs(tr_diagonal - expected) < 1e-10, "Diagonal matrix trace should equal sum of diagonal elements"
131+
132+
133+
if __name__ == "__main__":
134+
import doctest
135+
136+
doctest.testmod()
137+
test_trace()

0 commit comments

Comments
 (0)