Skip to content

Commit d89db3a

Browse files
authored
[Feature] reindexdims as a generalized permutedims (#23)
* add `reindexdims` * add some small tests
1 parent c6e4fc3 commit d89db3a

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

src/SparseArrayKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using TupleTools
77

88
export SparseArray
99
export nonzero_pairs, nonzero_keys, nonzero_values, nonzero_length
10+
export reindexdims, reindexdims!
1011

1112
include("sparsearray.jl")
1213
include("base.jl")

src/base.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,27 @@ function Base.reshape(parent::SparseArray{T}, dims::Dims) where {T}
6969
end
7070
return child
7171
end
72+
73+
@doc """
74+
reindexdims(A, p)
75+
reindexdims!(C, A, p)
76+
77+
Reindex the dimensions (axes) of array `A`. `p` is a tuple of integers specifying which indices are selected.
78+
This is similar to `permutedims(!)`, but also allows both repeated and omitted integers.
79+
The former boils down to a broadcasting along the diagonal, i.e. `C[i, i, j, k, ...] = A[i, j, k, ...]`,
80+
while the latter signifies a reduction over the omitted index, i.e. `C[j, k, ...] = ∑_i A[i, j, k, ...]`.
81+
""" reindexdims, reindexdims!
82+
83+
function reindexdims(A::SparseArray, p::IndexTuple)
84+
C = similar(A, TupleTools.getindices(size(A), p))
85+
return reindexdims!(C, A, p)
86+
end
87+
function reindexdims!(C::SparseArray{T, N}, A::SparseArray, p::IndexTuple{N}) where {T, N}
88+
_zero!(C)
89+
_sizehint!(C, nonzero_length(A))
90+
for (IA, vA) in nonzero_pairs(A)
91+
IC = CartesianIndex(TupleTools.getindices(IA.I, p))
92+
increaseindex!(C, vA, IC)
93+
end
94+
return C
95+
end

test/basic.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module BasicTests
2+
23
using SparseArrayKit
34
using Test, TestExtras, LinearAlgebra, Random
5+
using TupleTools
46

57
#=
68
generate a whole bunch of random contractions, compare with the dense result
@@ -96,4 +98,19 @@ end
9698
@test size(SparseArray(I, 3, 8)) == (3, 8)
9799
end
98100

101+
@timedtestset "Index manipulations" begin
102+
dims = (2, 3, 4)
103+
A = randn_sparse(Float64, dims)
104+
@test @constinferred(reindexdims(A, (2, 1, 3))) == permutedims(A, (2, 1, 3))
105+
106+
A_expanded = @constinferred reindexdims(A, (1, 1, 2, 3))
107+
@test size(A_expanded) == TupleTools.getindices(size(A), (1, 1, 2, 3))
108+
@test norm(A_expanded) norm(A)
109+
@test reindexdims(A_expanded, (1, 3, 4)) == A
110+
111+
A_reduced = @constinferred reindexdims(A, (1, 2))
112+
@test size(A_reduced) == TupleTools.getindices(size(A), (1, 2))
113+
@test Array(A_reduced) sum(Array(A); dims = 3)
114+
end
115+
99116
end

0 commit comments

Comments
 (0)