diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index b9d060fec..3bc1ebf48 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) return values_sorted, perms end +function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false) + I = sectortype(values) + p = sortperm(parent(values); by, rev) + + if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + return n <= 0 ? nothing : p[min(n, length(p))] + else + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + cumulative_dim = cumsum(Base.permute!(parent(dims), p)) + k = findlast(<=(n), cumulative_dim) + return isnothing(k) ? k : p[k] + end +end + # findtruncated # ------------- # Generic fallback @@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end +# TruncationByOrder strategy: +# - find the howmany'th value of the input sorted according to the strategy +# - discard everything that is ordered after that value + function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) - totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0) - while totaldim > strategy.howmany - next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) - isnothing(next) && break - _, cmin = next - truncdim[cmin] -= 1 - totaldim -= dim(cmin) - truncdim[cmin] == 0 && delete!(truncdim, cmin) + k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) + + if isnothing(k) + # discard everything + return SectorDict{sectortype(values), UnitRange{Int}}() + else + val = strategy.by(values[k]) + strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev) + return MAK.findtruncated_svd(values, strategy) end - return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter) return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values))