diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 4a203f64..6b89b64f 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -78,6 +78,16 @@ function eig_pullback!( end return ΔA end +function eig_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_trunc_pullback!( @@ -151,6 +161,16 @@ function eig_trunc_pullback!( end return ΔA end +function eig_trunc_pullback!( + ΔA::Diagonal, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_vals_pullback!( diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 195539cf..a9b7423c 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -68,6 +68,16 @@ function eigh_pullback!( end return ΔA end +function eigh_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_trunc_pullback!( @@ -141,6 +151,16 @@ function eigh_trunc_pullback!( end return ΔA end +function eigh_trunc_pullback!( + ΔA::AbstractMatrix, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_vals_pullback!(