Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false)
return values_sorted, perms
end

function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false)
I = sectortype(values)
p = sortperm(parent(values); by, rev)

if FusionStyle(I) isa UniqueFusion # dimensions are all 1
return n <= 0 ? nothing : p[min(n, length(p))]
else
dims = similar(values, Base.promote_op(dim, I))
for (c, v) in pairs(dims)
fill!(v, dim(c))
end
cumulative_dim = cumsum(Base.permute!(parent(dims), p))
k = findlast(<=(n), cumulative_dim)
return isnothing(k) ? k : p[k]
end
end

# findtruncated
# -------------
# Generic fallback
Expand All @@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation)
return SectorDict(c => Colon() for c in keys(values))
end

# TruncationByOrder strategy:
# - find the howmany'th value of the input sorted according to the strategy
# - discard everything that is ordered after that value

function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder)
values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev)
inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany))
return SectorDict(c => perms[c][I] for (c, I) in inds)
end
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder)
I = keytype(values)
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values))
totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0)
while totaldim > strategy.howmany
next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev)
isnothing(next) && break
_, cmin = next
truncdim[cmin] -= 1
totaldim -= dim(cmin)
truncdim[cmin] == 0 && delete!(truncdim, cmin)
k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev)

if isnothing(k)
# discard everything
return SectorDict{sectortype(values), UnitRange{Int}}()
else
val = strategy.by(values[k])
strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev)
return MAK.findtruncated_svd(values, strategy)
end
return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim)
end
# disambiguate
MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) =
MAK.findtruncated(values, strategy)

function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter)
return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values))
Expand Down
Loading