diff --git a/docs/src/index.md b/docs/src/index.md index aa0aecea..5f300b37 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -85,6 +85,19 @@ See `methods(as)` for all the constructors, `?as` for their documentation. as ``` +Transforms which produce `NamedTuple`s can be `merge`d, which internally calls `Base.merge`; name collisions will thus follow `Base` behavior, which is that the right-most instance will be kept. +When using e.g. [`ConstructionBase.setproperties`](https://juliaobjects.github.io/ConstructionBase.jl/stable/#ConstructionBase.setproperties) to map a vector onto a subset of parameters stored in a struct, this functionality allows transforms for different parameter subsets to be constructed for use separately or together: + +```julia +t_a = as((;a = asℝ₊)) +t_b = as((;b = as𝕀)) +t_c = as((;c = TVShift(5) ∘ TVExp())) +t_ab = merge(t_a, t_b) +t_abc = merge(t_ab, t_c) +t_abc = merge(t_a, t_b, t_c) +t_collision = merge(t_a, as((;a = asℝ₋))) # Will have a = asℝ₋, from rightmost +``` + ## Scalar transforms The symbol `∞` is a placeholder for infinity. It does not correspond to `Inf`, but acts as a placeholder for the correct dispatch. `-∞` is valid. diff --git a/src/aggregation.jl b/src/aggregation.jl index f3266d08..e5aca5a6 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -286,6 +286,15 @@ Base.getindex(t::TransformTuple, i::Int) = getindex(_inner(t), i) Base.propertynames(t::TransformTuple) = propertynames(_inner(t)) Base.getproperty(t::TransformTuple, i::Int) = getproperty(_inner(t), i) Base.getproperty(t::TransformTuple{<:NamedTuple}, i::Symbol) = getproperty(_inner(t), i) +""" +$(SIGNATURES) + +Merge multiple `TransformTuple{<:NamedTuple}` by merging the underlying `NamedTuple`s. +""" +function Base.merge(t1::TransformTuple{<:NamedTuple}, + ts::Vararg{TransformTuple{<:NamedTuple}}) + TransformTuple(merge(_inner(t1), map(_inner, ts)...)) +end function _summary_rows(transformation::TransformTuple, mime) inner = _inner(transformation) diff --git a/test/runtests.jl b/test/runtests.jl index 1c84c106..96c8ffb6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -506,6 +506,31 @@ end @test_throws ArgumentError("Property :b not in (:a, :c).") inverse(t, (a = 1.0, c = 2.0)) end +@testset "merging NamedTuple" begin + t1 = as((a = asℝ, b = as𝕀)) + t2 = as((c = CorrCholeskyFactor(3), d = unit_vector_norm(4))) + t3 = as((e = asℝ₊, f = as𝕀)) + tm = @inferred(merge(t1)) + @test tm == t1 + tm = @inferred(merge(t1, t2)) + @test tm == as((a = asℝ, b = as𝕀, c = CorrCholeskyFactor(3), d = unit_vector_norm(4))) + tm = @inferred(merge(t1, t2, t3)) + @test tm == as((a = asℝ, b = as𝕀, c = CorrCholeskyFactor(3), d = unit_vector_norm(4), + e = asℝ₊, f = as𝕀)) + x = randn(dimension(tm)) + y = transform(tm, x) + x′ = inverse(tm, y) + @test x ≈ x′ + # Check merge collision behavior: rightmost gets kept + t4 = as((b = asℝ₋, c = TVScale(2.0))) + tm = @inferred(merge(t1, t4)) + @test tm == as((a = asℝ, b = asℝ₋, c = TVScale(2.0))) + @test tm != as((a = asℝ, b = as𝕀, c = TVScale(2.0))) + tm = @inferred(merge(t1, t4, t2)) + @test tm == as((a = asℝ, b = asℝ₋, c = CorrCholeskyFactor(3), d = unit_vector_norm(4))) + +end + #### #### log density correctness checks ####