From d0d7c657bb3b0f88bef839bc3bba40f90c116eb4 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 5 Jan 2026 15:52:31 +0100 Subject: [PATCH 01/19] Test truncated methods with GPU arrays --- test/eig.jl | 8 ++++---- test/eigh.jl | 16 ++++++++-------- test/testsuite/eig.jl | 8 ++++---- test/testsuite/eigh.jl | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/test/eig.jl b/test/eig.jl index 7cc54c5d..df6fdf86 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -19,10 +19,10 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(123) if T ∈ BLASFloats if CUDA.functional() - TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false) - TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false) - TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false) - TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + TestSuite.test_eig(CuMatrix{T}, (m, m)) + TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),)) + TestSuite.test_eig(Diagonal{T, CuVector{T}}, m) + TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) end #= not yet supported if AMDGPU.functional() diff --git a/test/eigh.jl b/test/eigh.jl index 8766ccc0..1266a1a2 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -22,10 +22,10 @@ for T in (BLASFloats..., GenericFloats...) 
CUSOLVER_Jacobi(), CUSOLVER_DivideAndConquer(), ) - TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) - TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) - TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false) - TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + TestSuite.test_eigh(CuMatrix{T}, (m, m)) + TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS) + TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m) + TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) end if AMDGPU.functional() ROCSOLVER_EIGH_ALGS = ( @@ -34,10 +34,10 @@ for T in (BLASFloats..., GenericFloats...) ROCSOLVER_QRIteration(), ROCSOLVER_Bisection(), ) - TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) - TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) - TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false) - TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + TestSuite.test_eigh(ROCMatrix{T}, (m, m)) + TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m) + TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) end end if !is_buildkite diff --git a/test/testsuite/eig.jl b/test/testsuite/eig.jl index 2dbea8b9..9775162c 100644 --- a/test/testsuite/eig.jl +++ b/test/testsuite/eig.jl @@ -3,19 +3,19 @@ using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra: I using GenericSchur -function test_eig(T::Type, sz; test_trunc = true, kwargs...) +function test_eig(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "eig $summary_str" begin test_eig_full(T, sz; kwargs...) - test_trunc && test_eig_trunc(T, sz; kwargs...) + test_eig_trunc(T, sz; kwargs...) 
end end -function test_eig_algs(T::Type, sz, algs; test_trunc = true, kwargs...) +function test_eig_algs(T::Type, sz, algs; kwargs...) summary_str = testargs_summary(T, sz) return @testset "eig algorithms $summary_str" begin test_eig_full_algs(T, sz, algs; kwargs...) - test_trunc && test_eig_trunc_algs(T, sz, algs; kwargs...) + test_eig_trunc_algs(T, sz, algs; kwargs...) end end diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index df6e4d6e..8e513a1c 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -3,19 +3,19 @@ using GenericLinearAlgebra using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra: I, opnorm -function test_eigh(T::Type, sz; test_trunc = true, kwargs...) +function test_eigh(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "eigh $summary_str" begin test_eigh_full(T, sz; kwargs...) - test_trunc && test_eigh_trunc(T, sz; kwargs...) + test_eigh_trunc(T, sz; kwargs...) end end -function test_eigh_algs(T::Type, sz, algs; test_trunc = true, kwargs...) +function test_eigh_algs(T::Type, sz, algs; kwargs...) summary_str = testargs_summary(T, sz) return @testset "eigh algorithms $summary_str" begin test_eigh_full_algs(T, sz, algs; kwargs...) - test_trunc && test_eigh_trunc_algs(T, sz, algs; kwargs...) + test_eigh_trunc_algs(T, sz, algs; kwargs...) end end From 72a18b3a77303488be40090654a822019c7ad3bf Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 5 Jan 2026 16:02:42 +0100 Subject: [PATCH 02/19] Setup D0 --- test/testsuite/eig.jl | 4 +- test/testsuite/eigh.jl | 87 +++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/test/testsuite/eig.jl b/test/testsuite/eig.jl index 9775162c..61ed1fc8 100644 --- a/test/testsuite/eig.jl +++ b/test/testsuite/eig.jl @@ -78,7 +78,7 @@ function test_eig_trunc( Ac = deepcopy(A) Tc = complex(eltype(T)) # eigenvalues are sorted by ascending real component... 
- D₀ = sort!(eig_vals(A); by = abs, rev = true) + D₀ = collect(sort!(eig_vals(A); by = abs, rev = true)) m = size(A, 1) rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin @@ -150,7 +150,7 @@ function test_eig_trunc_algs( Ac = deepcopy(A) Tc = complex(eltype(T)) # eigenvalues are sorted by ascending real component... - D₀ = sort!(eig_vals(A; alg); by = abs, rev = true) + D₀ = collect(sort!(eig_vals(A; alg); by = abs, rev = true)) m = size(A, 1) rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index 8e513a1c..36c84c11 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -76,50 +76,49 @@ function test_eigh_trunc( A = A * A' A = project_hermitian!(A) Ac = deepcopy(A) - if !(T <: Diagonal) - m = size(A, 1) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(eltype(T)))) - atol = sqrt(eps(real(eltype(T)))) - # truncrank - D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # trunctol - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - #truncerror - s = 1 - sqrt(eps(real(eltype(T)))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(eltype(T)))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D4, V4 = @testinferred eigh_trunc_no_error(A; trunc) - @test length(diagview(D4)) == r - @test A * V4 ≈ V4 * D4 - 
- # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 - @test V4 * (V4' * V1) ≈ V1 - end + m = size(A, 1) + D₀ = collect(reverse(eigh_vals(A))) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + atol = sqrt(eps(real(eltype(T)))) + # truncrank + D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r)) + @test length(diagview(D1)) == r + @test isisometric(V1) + @test A * V1 ≈ V1 * D1 + @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol + trunc = trunctol(; atol = s * D₀[r + 1]) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D2)) == r + @test isisometric(V2) + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + #truncerror + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D4, V4 = @testinferred eigh_trunc_no_error(A; trunc) + @test length(diagview(D4)) == r + @test A * V4 ≈ V4 * D4 + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + @test V1 * (V1' * V3) ≈ V3 + @test V3 * (V3' * V1) ≈ V1 + @test V4 * (V4' * V1) ≈ V1 + @testset "specify truncation algorithm" begin atol = sqrt(eps(real(eltype(T)))) m4 = 4 @@ -156,7 +155,7 @@ function test_eigh_trunc_algs( Ac = deepcopy(A) m = size(A, 1) - D₀ = reverse(eigh_vals(A)) + D₀ = collect(reverse(eigh_vals(A))) r = m - 2 s = 1 + sqrt(eps(real(eltype(T)))) # truncrank From e3dd5e85fcbb5bc6c4d91ccaddd101bd2350686f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 5 Jan 2026 16:16:34 +0100 Subject: [PATCH 03/19] Try avoiding partialsortperm --- 
src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 883c7759..72f16abc 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -49,7 +49,7 @@ findtruncated(values::AbstractVector, ::NoTruncation) = Colon() function findtruncated(values::AbstractVector, strategy::TruncationByOrder) howmany = min(strategy.howmany, length(values)) - return partialsortperm(values, 1:howmany; strategy.by, strategy.rev) + return sortperm(values; strategy.by, strategy.rev)[1:howmany] end function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) strategy.by === abs || return findtruncated(values, strategy) From 30ae6a59a4d38c98464ba6145a5e6b8c7f938634 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 5 Jan 2026 17:00:50 +0100 Subject: [PATCH 04/19] Make _truncerr_impl more GPU friendly --- src/implementations/truncation.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 72f16abc..b4ce2fbf 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -96,14 +96,8 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real = # fast path to avoid checking all values ϵᵖ ≥ Nᵖ && return Base.OneTo(0) - truncerrᵖ = zero(real(eltype(values))) - rank = length(values) - for i in reverse(I) - truncerrᵖ += by(values[i]) - truncerrᵖ ≥ ϵᵖ && break - rank -= 1 - end - + truncerrᵖ_array = cumsum(map(by, values[reverse(I)])) + rank = length(values) - (findfirst(truncerrᵖ -> truncerrᵖ ≥ ϵᵖ, truncerrᵖ_array) - 1) return Base.OneTo(rank) end From 8cfd87c7cbc888d5796cfe96df72f63baebe9839 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 5 Jan 2026 20:13:34 +0100 Subject: [PATCH 05/19] Make eigh_vals/full sort for Diagonal --- src/implementations/eigh.jl | 19 
+++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 40f2c557..54c6a20b 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -43,7 +43,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA @assert isdiag(A) m = size(A, 1) D, V = DV - @assert D isa Diagonal && V isa Diagonal + @assert D isa Diagonal @check_size(D, (m, m)) @check_scalar(D, A, real) @check_size(V, (m, m)) @@ -79,7 +79,7 @@ function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_err end function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm) - return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A) + return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A, size(A)...) end function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm) return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) @@ -146,15 +146,26 @@ end function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV - D === A || (diagview(D) .= real.(diagview(A))) + I = sortperm(real.(diagview(A))) + if D === A + sort!(diagview(A)) + else + diagview(D) .= real.(diagview(A))[I] + end one!(V) + Base.permutecols!!(V, I) return D, V end function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm) check_input(eigh_vals!, A, D, alg) Ad = diagview(A) - D === Ad || (D .= real.(Ad)) + if D === Ad + sort!(Ad) + else + I = sortperm(real.(Ad)) + D .= real.(Ad[I]) + end return D end From b8999cb7bc20e3ac94a2a980fed82d8f1cb3b2d0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 07:09:38 +0100 Subject: [PATCH 06/19] Update src/implementations/eigh.jl Co-authored-by: Lukas Devos --- src/implementations/eigh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/eigh.jl 
b/src/implementations/eigh.jl index 54c6a20b..b4573298 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -146,7 +146,7 @@ end function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV - I = sortperm(real.(diagview(A))) + I = sortperm(diagview(A); by = real) if D === A sort!(diagview(A)) else From 7da7e06a3255720bfb5a9f1050a228b482a05830 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 07:09:45 +0100 Subject: [PATCH 07/19] Update src/implementations/eigh.jl Co-authored-by: Lukas Devos --- src/implementations/eigh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index b4573298..62b13ac8 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -163,8 +163,8 @@ function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm) if D === Ad sort!(Ad) else - I = sortperm(real.(Ad)) - D .= real.(Ad[I]) + D .= real.(Ad) + sort!(D) end return D end From 661677520769f4d62da5e15a1d61ac9eb354b9b9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 07:09:53 +0100 Subject: [PATCH 08/19] Update src/implementations/truncation.jl Co-authored-by: Lukas Devos --- src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index b4ce2fbf..eae8df69 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -97,7 +97,7 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real = ϵᵖ ≥ Nᵖ && return Base.OneTo(0) truncerrᵖ_array = cumsum(map(by, values[reverse(I)])) - rank = length(values) - (findfirst(truncerrᵖ -> truncerrᵖ ≥ ϵᵖ, truncerrᵖ_array) - 1) + rank = length(values) - (findfirst(≥(ϵᵖ), truncerrᵖ_array) - 1) return Base.OneTo(rank) end From 82178f7c426d5d0fd0990522a7f136cf07711a79 Mon Sep 17 00:00:00 2001 From: 
Katharine Hyatt Date: Tue, 6 Jan 2026 07:32:08 +0100 Subject: [PATCH 09/19] Try to permutecols in a GPU friendly way --- src/implementations/eigh.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 62b13ac8..f96de467 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -152,8 +152,9 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) else diagview(D) .= real.(diagview(A))[I] end - one!(V) - Base.permutecols!!(V, I) + zero!(V) + Is = [ix -> CartesianIndex(ix, I[ix]) for ix in 1:size(A, 1)] + V[Is] .= one(eltype(A)) return D, V end From 68906045adf8b12e6c1b07eb40f4b712437a1601 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 07:48:30 +0100 Subject: [PATCH 10/19] Dumb typo fix --- src/implementations/eigh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index f96de467..8f639551 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -153,7 +153,7 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) diagview(D) .= real.(diagview(A))[I] end zero!(V) - Is = [ix -> CartesianIndex(ix, I[ix]) for ix in 1:size(A, 1)] + Is = [CartesianIndex(ix, I[ix]) for ix in 1:size(A, 1)] V[Is] .= one(eltype(A)) return D, V end From 7329aa63704bf9e84ac4021cc40a31dd7d812177 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 04:13:50 -0500 Subject: [PATCH 11/19] Working GPU col permute --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 7 +++++++ ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 7 +++++++ src/implementations/eigh.jl | 5 +++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index abfa6353..68ff46c4 100644 --- 
a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -167,4 +167,11 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end +function MatrixAlgebraKit.permute_V_cols!(V, I::ROCVector{Int}) + I_ixs = ROCArray(collect(1:size(V, 1))) + c_ixs = map(CartesianIndex, I, I_ixs) + V[c_ixs] .= one(eltype(V)) + return V +end + end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index e3acb553..2ae7d3b2 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -191,4 +191,11 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end +function MatrixAlgebraKit.permute_V_cols!(V, I::CuVector{Int}) + I_ixs = CuArray(collect(1:size(V, 1))) + c_ixs = map(CartesianIndex, I, I_ixs) + V[c_ixs] .= one(eltype(V)) + return V +end + end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 8f639551..7b9e998b 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -141,6 +141,8 @@ function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm) return DVtrunc end +permute_V_cols!(V, I::Vector{Int}) = Base.permutecols!!(V, I) + # Diagonal logic # -------------- function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) @@ -153,8 +155,7 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) diagview(D) .= real.(diagview(A))[I] end zero!(V) - Is = [CartesianIndex(ix, I[ix]) for ix in 1:size(A, 1)] - V[Is] .= one(eltype(A)) + V = permute_V_cols!(V, I) return D, V end From 71cae210861e8f03b1e6136fd1efac3ad6bf75ef Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 04:49:02 -0500 Subject: [PATCH 12/19] Don't test trunc for AMD --- test/eigh.jl | 9 +++++---- test/testsuite/eigh.jl | 6 +++--- 2 files changed, 8 insertions(+), 7 
deletions(-) diff --git a/test/eigh.jl b/test/eigh.jl index 1266a1a2..2efb4e15 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -34,10 +34,11 @@ for T in (BLASFloats..., GenericFloats...) ROCSOLVER_QRIteration(), ROCSOLVER_Bisection(), ) - TestSuite.test_eigh(ROCMatrix{T}, (m, m)) - TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS) - TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m) - TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) + # see https://github.com/JuliaGPU/AMDGPU.jl/issues/837 + TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) + TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false) + TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) end end if !is_buildkite diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index 36c84c11..2d99f3d0 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -3,11 +3,11 @@ using GenericLinearAlgebra using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra: I, opnorm -function test_eigh(T::Type, sz; kwargs...) +function test_eigh(T::Type, sz; test_trunc = true, kwargs...) summary_str = testargs_summary(T, sz) return @testset "eigh $summary_str" begin test_eigh_full(T, sz; kwargs...) - test_eigh_trunc(T, sz; kwargs...) + test_trunc && test_eigh_trunc(T, sz; kwargs...) end end @@ -15,7 +15,7 @@ function test_eigh_algs(T::Type, sz, algs; kwargs...) summary_str = testargs_summary(T, sz) return @testset "eigh algorithms $summary_str" begin test_eigh_full_algs(T, sz, algs; kwargs...) - test_eigh_trunc_algs(T, sz, algs; kwargs...) + test_trunc && test_eigh_trunc_algs(T, sz, algs; kwargs...) 
end end From bf05fb22fb5dd2cf1341525d10ba89b0cc75743f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 05:01:12 -0500 Subject: [PATCH 13/19] Dumb typo again --- test/testsuite/eigh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index 2d99f3d0..087ee734 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -11,7 +11,7 @@ function test_eigh(T::Type, sz; test_trunc = true, kwargs...) end end -function test_eigh_algs(T::Type, sz, algs; kwargs...) +function test_eigh_algs(T::Type, sz, algs; test_trunc = true, kwargs...) summary_str = testargs_summary(T, sz) return @testset "eigh algorithms $summary_str" begin test_eigh_full_algs(T, sz, algs; kwargs...) From ad2a4957664be2f0680c4314233da800eb808141 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 05:41:25 -0500 Subject: [PATCH 14/19] Use existing implementation for column permute --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 7 ------- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 7 ------- src/implementations/eigh.jl | 6 +++--- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 68ff46c4..abfa6353 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -167,11 +167,4 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end -function MatrixAlgebraKit.permute_V_cols!(V, I::ROCVector{Int}) - I_ixs = ROCArray(collect(1:size(V, 1))) - c_ixs = map(CartesianIndex, I, I_ixs) - V[c_ixs] .= one(eltype(V)) - return V -end - end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 2ae7d3b2..e3acb553 100644 --- 
a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -191,11 +191,4 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end -function MatrixAlgebraKit.permute_V_cols!(V, I::CuVector{Int}) - I_ixs = CuArray(collect(1:size(V, 1))) - c_ixs = map(CartesianIndex, I, I_ixs) - V[c_ixs] .= one(eltype(V)) - return V -end - end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 7b9e998b..7301b224 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -141,8 +141,6 @@ function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm) return DVtrunc end -permute_V_cols!(V, I::Vector{Int}) = Base.permutecols!!(V, I) - # Diagonal logic # -------------- function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) @@ -155,7 +153,9 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) diagview(D) .= real.(diagview(A))[I] end zero!(V) - V = permute_V_cols!(V, I) + n = size(A, 1) + I .+= (0:(n - 1)) .* n + V[I] .= Ref(one(eltype(V))) return D, V end From 6df94256cac700093c2c72369d0088226470b7c7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 12:11:05 +0100 Subject: [PATCH 15/19] Update src/implementations/eigh.jl Co-authored-by: Jutho --- src/implementations/eigh.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 7301b224..3160fcb0 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -151,6 +151,12 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) sort!(diagview(A)) else diagview(D) .= real.(diagview(A))[I] + diagA = diagview(A) + I = sortperm(diagA; by = real) + if D === A + sort!(diagA) + else + diagview(D) .= real.(view(diagA, I)) end zero!(V) n = size(A, 1) From c71079b50746ba239ffea07e861a3ed568fcabee Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 06:17:19 -0500 
Subject: [PATCH 16/19] Fix suggestions --- src/implementations/eigh.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 3160fcb0..7301b224 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -151,12 +151,6 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) sort!(diagview(A)) else diagview(D) .= real.(diagview(A))[I] - diagA = diagview(A) - I = sortperm(diagA; by = real) - if D === A - sort!(diagA) - else - diagview(D) .= real.(view(diagA, I)) end zero!(V) n = size(A, 1) From 0099fbfe94d2fa9c90ddb779632efb1307e3c8f8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 06:18:16 -0500 Subject: [PATCH 17/19] Actually fix --- src/implementations/eigh.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 7301b224..decb8732 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -147,10 +147,12 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV I = sortperm(diagview(A); by = real) + diagA = diagview(A) + I = sortperm(diagA; by = real) if D === A - sort!(diagview(A)) + sort!(diagA) else - diagview(D) .= real.(diagview(A))[I] + diagview(D) .= real.(view(diagA, I)) end zero!(V) n = size(A, 1) From 5a488a59e9200ad90cf509cfdf921ab3a77a5363 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 14:27:23 +0100 Subject: [PATCH 18/19] Use permute to avoid double sorting --- src/implementations/eigh.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index decb8732..19a190b1 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -146,11 +146,10 @@ end function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV - I = sortperm(diagview(A); 
by = real) diagA = diagview(A) I = sortperm(diagA; by = real) if D === A - sort!(diagA) + permute!(diagA, I) else diagview(D) .= real.(view(diagA, I)) end From b4ce7518cbf14bfff3ccda6511ace6ba3ab16807 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 6 Jan 2026 14:57:50 +0100 Subject: [PATCH 19/19] Use cumsum in a slightly better way --- src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index eae8df69..be730bed 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -96,7 +96,7 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real = # fast path to avoid checking all values ϵᵖ ≥ Nᵖ && return Base.OneTo(0) - truncerrᵖ_array = cumsum(map(by, values[reverse(I)])) + truncerrᵖ_array = cumsum(map(by, view(values, reverse(I)))) rank = length(values) - (findfirst(≥(ϵᵖ), truncerrᵖ_array) - 1) return Base.OneTo(rank) end