diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 64acb509..9aab7a8e 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -443,7 +443,9 @@ left_orth_alg(alg::LeftOrthAlgorithm) = alg left_orth_alg(alg::QRAlgorithms) = LeftOrthViaQR(alg) left_orth_alg(alg::PolarAlgorithms) = LeftOrthViaPolar(alg) left_orth_alg(alg::SVDAlgorithms) = LeftOrthViaSVD(alg) +left_orth_alg(alg::DiagonalAlgorithm) = LeftOrthViaQR(alg) left_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD(alg) +left_orth_alg(alg::TruncatedAlgorithm{DiagonalAlgorithm}) = LeftOrthViaSVD(alg) """ right_orth_alg(alg::AbstractAlgorithm) -> RightOrthAlgorithm @@ -478,7 +480,9 @@ right_orth_alg(alg::RightOrthAlgorithm) = alg right_orth_alg(alg::LQAlgorithms) = RightOrthViaLQ(alg) right_orth_alg(alg::PolarAlgorithms) = RightOrthViaPolar(alg) right_orth_alg(alg::SVDAlgorithms) = RightOrthViaSVD(alg) +right_orth_alg(alg::DiagonalAlgorithm) = RightOrthViaLQ(alg) right_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD(alg) +right_orth_alg(alg::TruncatedAlgorithm{DiagonalAlgorithm}) = RightOrthViaSVD(alg) """ left_null_alg(alg::AbstractAlgorithm) -> LeftNullAlgorithm diff --git a/test/orthnull.jl b/test/orthnull.jl index dec12946..084c0a9b 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -19,16 +19,16 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) if T ∈ BLASFloats if CUDA.functional() TestSuite.test_orthnull(CuMatrix{T}, (m, n); test_nullity = false) - n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m; test_orthnull = false) + n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m) end if AMDGPU.functional() TestSuite.test_orthnull(ROCMatrix{T}, (m, n); test_nullity = false) - n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m; test_orthnull = false) + n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m) end end if !is_buildkite TestSuite.test_orthnull(T, (m, n)) AT = Diagonal{T, Vector{T}} - TestSuite.test_orthnull(AT, m; test_orthnull = false) + TestSuite.test_orthnull(AT, m) end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2f3fde50..dfc94244 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -72,9 +72,11 @@ is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg) isleftcomplete(V, N) = V * V' + N * N' ≈ I isleftcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isleftcomplete(collect(V), collect(N)) isleftcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isleftcomplete(collect(V), collect(N)) +isleftcomplete(V::Diagonal{T, <:ROCVector{<:T}}, N::Diagonal{T, <:ROCVector{<:T}}) where {T} = isleftcomplete(collect(V), collect(N)) isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N)) isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N)) +isrightcomplete(V::Diagonal{T, <:ROCVector{<:T}}, N::Diagonal{T, <:ROCVector{<:T}}) where {T} = isrightcomplete(collect(V), collect(N)) instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1] # AMDGPU can't generate ComplexF32 random numbers diff --git a/test/testsuite/orthnull.jl b/test/testsuite/orthnull.jl index 79349d2a..16b86c82 100644 --- a/test/testsuite/orthnull.jl +++ b/test/testsuite/orthnull.jl @@ -17,12 +17,12 @@ _right_orth_lq!(x, CVᴴ; kwargs...) = right_orth!(x, CVᴴ; alg = :lq, kwargs.. _right_orth_polar(x; kwargs...) = right_orth(x; alg = :polar, kwargs...) _right_orth_polar!(x, CVᴴ; kwargs...) = right_orth!(x, CVᴴ; alg = :polar, kwargs...) -function test_orthnull(T::Type, sz; test_nullity = true, test_orthnull = true, kwargs...) +function test_orthnull(T::Type, sz; test_nullity = true, kwargs...) summary_str = testargs_summary(T, sz) return @testset "orthnull $summary_str" begin - test_orthnull && test_left_orthnull(T, sz; kwargs...) + test_left_orthnull(T, sz; kwargs...) test_nullity && test_left_nullity(T, sz; kwargs...) - test_orthnull && test_right_orthnull(T, sz; kwargs...) + test_right_orthnull(T, sz; kwargs...) test_nullity && test_right_nullity(T, sz; kwargs...) end end @@ -276,9 +276,9 @@ function test_right_orthnull( # passing an algorithm C, Vᴴ = @testinferred right_orth(A; alg = MatrixAlgebraKit.default_lq_algorithm(A)) - Nᴴ = @testinferred right_null(A; alg = :lq, positive = true) @test C isa typeof(A) && size(C) == (m, minmn) @test Vᴴ isa typeof(A) && size(Vᴴ) == (minmn, n) + Nᴴ = @testinferred right_null(A; alg = :lq, positive = true) @test eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n) @test C * Vᴴ ≈ A @test isisometric(Vᴴ; side = :right)