Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ See `methods(as)` for all the constructors, `?as` for their documentation.
as
```

Transforms which produce `NamedTuple`s can be `merge`d, which internally calls `Base.merge`; name collisions will thus follow `Base` behavior, which is that the right-most instance will be kept.
When using e.g. [`ConstructionBase.setproperties`](https://juliaobjects.github.io/ConstructionBase.jl/stable/#ConstructionBase.setproperties) to map a vector onto a subset of parameters stored in a struct, this functionality allows transforms for different parameter subsets to be constructed for use separately or together:

```julia
t_a = as((;a = asℝ₊))
t_b = as((;b = as𝕀))
t_c = as((;c = TVShift(5) ∘ TVExp()))
t_ab = merge(t_a, t_b)
t_abc = merge(t_ab, t_c)
t_abc = merge(t_a, t_b, t_c)
t_collision = merge(t_a, as((;a = asℝ₋))) # Will have a = asℝ₋, from rightmost
```

## Scalar transforms

The symbol `∞` is a placeholder for infinity. It does not correspond to `Inf`, but acts as a placeholder for the correct dispatch. `-∞` is valid.
Expand Down
9 changes: 9 additions & 0 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ Base.getindex(t::TransformTuple, i::Int) = getindex(_inner(t), i)
Base.propertynames(t::TransformTuple) = propertynames(_inner(t))
Base.getproperty(t::TransformTuple, i::Int) = getproperty(_inner(t), i)
Base.getproperty(t::TransformTuple{<:NamedTuple}, i::Symbol) = getproperty(_inner(t), i)
"""
$(SIGNATURES)

Merge multiple `TransformTuple{<:NamedTuple}` by merging the underlying `NamedTuple`s.
"""
function Base.merge(t1::TransformTuple{<:NamedTuple},
ts::Vararg{TransformTuple{<:NamedTuple}})
TransformTuple(merge(_inner(t1), map(_inner, ts)...))
end

function _summary_rows(transformation::TransformTuple, mime)
inner = _inner(transformation)
Expand Down
25 changes: 25 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,31 @@ end
@test_throws ArgumentError("Property :b not in (:a, :c).") inverse(t, (a = 1.0, c = 2.0))
end

@testset "merging NamedTuple" begin
t1 = as((a = asℝ, b = as𝕀))
t2 = as((c = CorrCholeskyFactor(3), d = unit_vector_norm(4)))
t3 = as((e = asℝ₊, f = as𝕀))
tm = @inferred(merge(t1))
@test tm == t1
tm = @inferred(merge(t1, t2))
@test tm == as((a = asℝ, b = as𝕀, c = CorrCholeskyFactor(3), d = unit_vector_norm(4)))
tm = @inferred(merge(t1, t2, t3))
@test tm == as((a = asℝ, b = as𝕀, c = CorrCholeskyFactor(3), d = unit_vector_norm(4),
e = asℝ₊, f = as𝕀))
x = randn(dimension(tm))
y = transform(tm, x)
x′ = inverse(tm, y)
@test x ≈ x′
# Check merge collision behavior: rightmost gets kept
t4 = as((b = asℝ₋, c = TVScale(2.0)))
tm = @inferred(merge(t1, t4))
@test tm == as((a = asℝ, b = asℝ₋, c = TVScale(2.0)))
@test tm != as((a = asℝ, b = as𝕀, c = TVScale(2.0)))
tm = @inferred(merge(t1, t4, t2))
@test tm == as((a = asℝ, b = asℝ₋, c = CorrCholeskyFactor(3), d = unit_vector_norm(4)))

end

####
#### log density correctness checks
####
Expand Down
Loading