From 25e53b59ef1cc60ff2b8a5bf5a3403389b9105ac Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 12:38:30 +0100 Subject: [PATCH 1/3] naively specialize diagonal pullback --- src/pullbacks/eigh.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 195539cf..11171685 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -68,6 +68,16 @@ function eigh_pullback!( end return ΔA end +function eigh_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_trunc_pullback!( @@ -141,6 +151,16 @@ function eigh_trunc_pullback!( end return ΔA end +function eigh_trunc_pullback!( + ΔA::Diagonal, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_vals_pullback!( From 519675524eb9d3d786a68205ecf6efb6e1baf876 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 14:01:43 +0100 Subject: [PATCH 2/3] add eig diagonal pullback --- src/pullbacks/eig.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 4a203f64..6b89b64f 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -78,6 +78,16 @@ function eig_pullback!( end return ΔA end +function eig_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_trunc_pullback!( @@ -151,6 +161,16 @@ function eig_trunc_pullback!( end return ΔA end +function eig_trunc_pullback!( + ΔA::Diagonal, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_vals_pullback!( From 48fe1a9f17abb87263c1614eb45e5b219ed1a5b8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 18 Jan 2026 07:44:01 -0500 Subject: [PATCH 3/3] add svd diagonal pullback --- src/pullbacks/svd.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index a8f8b70c..1608343e 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -99,6 +99,17 @@ function svd_pullback!( end return ΔA end +function svd_pullback!( + ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ svd_trunc_pullback!( @@ -201,6 +212,17 @@ function svd_trunc_pullback!( ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1) return ΔA end +function svd_trunc_pullback!( + ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ; + rank_atol::Real = 0, + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = svd_trunc_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ; rank_atol, degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ svd_vals_pullback!(