From 68ecc19d8d2d09025f5d7f9b07362efac9bee5bf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 17:17:27 +0100 Subject: [PATCH 1/6] guarantee hermitian results in polar --- src/implementations/polar.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index b94b2afa..de020dc6 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -53,7 +53,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) SsqrtVᴴ = lmul!(S, Vᴴ) - P = mul!(P, SsqrtVᴴ', SsqrtVᴴ) + P = _mul_herm!(P, SsqrtVᴴ) end return (W, P) end @@ -65,11 +65,24 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) USsqrt = rmul!(U, S) - P = mul!(P, USsqrt, USsqrt') + P = _mul_herm!(P, USsqrt') end return (P, Wᴴ) end +# Implement `mul!(C, A', A)` and guarantee the result is hermitian. +# For BLAS calls that dispatch to `syrk` or `herk` this works automatically +# for GPU this currently does not seem to be guaranteed so we manually project +function _mul_herm!(C, A) + mul!(C, A', A) + project_hermitian!(C) + return C +end +function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T} + mul!(C, A', A) + return C +end + # Implementation via Newton # -------------------------- function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton) From d2dac21299f6dd65686087cc35fb859386f958c6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 17:20:52 +0100 Subject: [PATCH 2/6] update tests --- test/testsuite/polar.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/testsuite/polar.jl b/test/testsuite/polar.jl index c610ba34..321356ee 100644 --- a/test/testsuite/polar.jl +++ b/test/testsuite/polar.jl @@ -25,16 +25,14 @@ function test_left_polar( @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) @test W * P ≈ A @test isisometric(W) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) @test W2 === W @test P2 === P @test W * P ≈ A @test isisometric(W) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) noP = similar(P, (0, 0)) W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) @@ -62,16 +60,14 @@ function test_right_polar( @test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1)) @test P * Wᴴ ≈ A @test isisometric(Wᴴ; side = :right) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg) @test P2 === P @test Wᴴ2 === Wᴴ @test P * Wᴴ ≈ A @test isisometric(Wᴴ; side = :right) - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) + @test isposdef(P) noP = similar(P, (0, 0)) P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg) From fb5b57fe2f01ed8d1c136d75e3252bac60f5a79a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 17:27:23 +0100 Subject: [PATCH 3/6] fix double `@testset` --- test/testsuite/polar.jl | 46 ++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/test/testsuite/polar.jl b/test/testsuite/polar.jl index 321356ee..d858c0c1 100644 --- a/test/testsuite/polar.jl +++ b/test/testsuite/polar.jl @@ -17,32 +17,30 @@ function test_left_polar( ) summary_str = testargs_summary(T, sz) return @testset "left_polar! algorithm $alg $summary_str" for alg in algs - @testset "algorithm $alg" for alg in algs - A = instantiate_matrix(T, sz) - Ac = deepcopy(A) - W, P = left_polar(A; alg) - @test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2)) - @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) - @test W * P ≈ A - @test isisometric(W) - @test isposdef(P) + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + W, P = left_polar(A; alg) + @test eltype(W) == eltype(A) && size(W) == (size(A, 1), size(A, 2)) + @test eltype(P) == eltype(A) && size(P) == (size(A, 2), size(A, 2)) + @test W * P ≈ A + @test isisometric(W) + @test isposdef(P) - W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) - @test W2 === W - @test P2 === P - @test W * P ≈ A - @test isisometric(W) - @test isposdef(P) + W2, P2 = @testinferred left_polar!(Ac, (W, P), alg) + @test W2 === W + @test P2 === P + @test W * P ≈ A + @test isisometric(W) + @test isposdef(P) - noP = similar(P, (0, 0)) - W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) - @test P2 === noP - @test W2 === W - @test isisometric(W) - P = W' * A # compute P explicitly to verify W correctness - @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) - @test isposdef(project_hermitian!(P)) - end + noP = similar(P, (0, 0)) + W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg) + @test P2 === noP + @test W2 === W + @test isisometric(W) + P = W' * A # compute P explicitly to verify W correctness + @test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) + @test isposdef(project_hermitian!(P)) end end From 6848da3028c37b34796181da131b6fe0d671f0a6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 19:06:00 +0100 Subject: [PATCH 4/6] change to `syrk` convention --- src/implementations/polar.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index de020dc6..ea9de7cf 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -53,7 +53,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) SsqrtVᴴ = lmul!(S, Vᴴ) - P = _mul_herm!(P, SsqrtVᴴ) + P = _mul_herm!(P, SsqrtVᴴ') end return (W, P) end @@ -65,7 +65,7 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD) if !isempty(P) S .= sqrt.(S) USsqrt = rmul!(U, S) - P = _mul_herm!(P, USsqrt') + P = _mul_herm!(P, USsqrt) end return (P, Wᴴ) end @@ -74,12 +74,12 @@ end # For BLAS calls that dispatch to `syrk` or `herk` this works automatically # for GPU this currently does not seem to be guaranteed so we manually project function _mul_herm!(C, A) - mul!(C, A', A) + mul!(C, A, A') project_hermitian!(C) return C end function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T} - mul!(C, A', A) + mul!(C, A, A') return C end From d53b792c8bb64a977949a2e41c32f50cc6ebd566 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 19:43:16 +0100 Subject: [PATCH 5/6] update changelog --- docs/src/changelog.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 8ce4f956..ac4c8438 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Fixed +- Polar decompositions return exact hermitian factors ([#143](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/143) + ## [0.6.1](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.0...v0.6.1) - 2025-12-28 ### Added From 02f3366bc14106ca8603dee32f95b8c8359f245a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 5 Jan 2026 14:12:07 -0500 Subject: [PATCH 6/6] some reshuffling to actually correctly dispatch --- .../MatrixAlgebraKitAMDGPUExt.jl | 8 ++++++++ ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 8 ++++++++ src/implementations/polar.jl | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ff150f24..fd0c1604 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -198,4 +198,12 @@ function MatrixAlgebraKit.truncate( return Vᴴ[ind, :], ind end +# avoids calling the BlasMat specialization that assumes syrk! or herk! is called +# TODO: remove once syrk! or herk! is defined +function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix{T}) where {T <: BlasFloat} + mul!(C, A, A') + project_hermitian!(C) + return C +end + end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 4d34dd9e..e3acb553 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -183,4 +183,12 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) return A, B end +# avoids calling the BlasMat specialization that assumes syrk! or herk! is called +# TODO: remove once syrk! or herk! is defined +function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T}) where {T <: BlasFloat} + mul!(C, A, A') + project_hermitian!(C) + return C +end + end diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index ea9de7cf..00f9cbbd 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -78,7 +78,7 @@ function _mul_herm!(C, A) project_hermitian!(C) return C end -function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T} +function _mul_herm!(C::YALAPACK.BlasMat{T}, A::YALAPACK.BlasMat{T}) where {T <: YALAPACK.BlasFloat} mul!(C, A, A') return C end