From c8576fa1edefa35e3de4b5eaba38a85f1983cf6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Tue, 25 Nov 2025 14:45:02 +0100 Subject: [PATCH 01/10] Reorganize inverse_eltype calculations. - remove `robust_eltype`, replace with `_ensure_float` - remove `Union{}` as a corner case for empty (named)tuples and `Constant`, use `Bool` - clarify API guarantees (`inverse_eltype` is always `<:Real`) - remove `_ensure_float` from `inverse!` definition, clarify invariant (eltype is always the same as `inverse_eltype`) - incidental: fix various minor code redundancies - fixes #148 (add tests) --- src/aggregation.jl | 18 +++++++--------- src/constant.jl | 2 +- src/generic.jl | 14 ++++++------- src/special_arrays.jl | 19 ++++++++--------- src/utilities.jl | 45 ++++++++++++++-------------------------- test/runtests.jl | 48 ++++++++++++++++++++++++++++--------------- 6 files changed, 70 insertions(+), 76 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index f3266d08..c673abeb 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -77,7 +77,7 @@ function transform_with(flag::LogJacFlag, transformation::ArrayTransformation, x len = prod(dims) # number of elements 𝐼 = reshape(range(index; length = len, step = d), dims) yℓ = map(index -> ((y, ℓ, _) = transform_with(flag, inner_transformation, x, index); (y, ℓ)), 𝐼) - ℓz = logjac_zero(flag, robust_eltype(x)) + ℓz = logjac_zero(flag, _ensure_float(eltype(x))) index′ = index + d * len first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ), index′ end @@ -85,7 +85,7 @@ end function transform_with(flag::LogJacFlag, t::ArrayTransformation{Identity}, x, index) index′ = index+dimension(t) y = reshape(x[index:(index′-1)], t.dims) - y, logjac_zero(flag, robust_eltype(x)), index′ + y, logjac_zero(flag, _ensure_float(eltype(x))), index′ end """ @@ -133,7 +133,7 @@ dimension(transformation::ViewTransformation) = prod(transformation.dims) function transform_with(flag::LogJacFlag, t::ViewTransformation, x, index) index′ = index + dimension(t) y = reshape(@view(x[index:(index′-1)]), t.dims) - y, logjac_zero(flag, robust_eltype(x)), index′ + y, logjac_zero(flag, _ensure_float(eltype(x))), index′ end function _domain_label(transformation::ViewTransformation, index::Int) @@ -141,10 +141,7 @@ function _domain_label(transformation::ViewTransformation, index::Int) _array_domain_label(asℝ, dims, index) end -function inverse_eltype(transformation::ViewTransformation, - ::Type{T}) where T <: AbstractArray - _ensure_float(eltype(T)) -end +inverse_eltype(transformation::ViewTransformation, ::Type{<:AbstractArray{T}}) where T = T function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation, y::AbstractArray) @@ -214,8 +211,7 @@ end function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation}, ::Type{T}) where T <: AbstractArray - inverse_eltype(transformation.inner_transformation, - _ensure_float(eltype(T))) + inverse_eltype(transformation.inner_transformation, eltype(T)) end function inverse_at!(x::AbstractVector, index, @@ -376,7 +372,7 @@ Helper function for transforming tuples. Used internally, to help type inference `transfom_tuple`. """ _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) = - (), logjac_zero(flag, robust_eltype(x)), index + (), logjac_zero(flag, _ensure_float(eltype(x))), index function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts) tfirst = first(ts) @@ -406,7 +402,7 @@ function _inverse_eltype_tuple(ts::NTransforms{N}, ::Type{T}) where {N,T<:Tuple} __inverse_eltype_tuple(ts, T) end function __inverse_eltype_tuple(ts::NTransforms, ::Type{Tuple{}}) - Union{} + Bool end function __inverse_eltype_tuple(ts::NTransforms, ::Type{T}) where {T<:Tuple} promote_type(inverse_eltype(Base.first(ts), fieldtype(T, 1)), diff --git a/src/constant.jl b/src/constant.jl index a4fdcd64..3cd317ed 100644 --- a/src/constant.jl +++ b/src/constant.jl @@ -19,7 +19,7 @@ function transform_with(logjac_flag::LogJacFlag, t::Constant, x::AbstractVector, t.value, logjac_zero(logjac_flag, eltype(x)), index end -inverse_eltype(t::Constant, ::Type) = Union{} +inverse_eltype(t::Constant, ::Type) = Bool function inverse_at!(x::AbstractVector, index, t::Constant, y) @argcheck t.value == y diff --git a/src/generic.jl b/src/generic.jl index f9fe6caf..707c6be4 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -165,11 +165,8 @@ with transform, so the following holds: inverse(t)(y) == inverse(t, y) == inverse(transform(t))(y) ``` -!!! note - `eltype(inverse(t, transform(t, x)))` is not necessarily equal to `eltype(x)`, - it is not guaranteed to be the narrowest possible type, and may change without - warning between versions. Some effort is made to come up with a reasonable - concrete type even in corner cases. +Note that `eltype(inverse(t, y)) ≡ inverse_eltype(t, typeof(y))` holds. See +[`inverse_eltype`](@ref). """ inverse(t::AbstractTransform) = Base.Fix1(inverse, t) @@ -229,6 +226,9 @@ The element type for vector `x` so that `inverse!(x, t, y::T)` works. 3. No dimension or input compatibility checks are guaranteed to be performed, even for values. + +The value is always a numerical type (ie `<:Real`), but in corner cases (eg +transformations of zero dimension) may not be a float. """ function inverse_eltype(t::AbstractTransform, y::T) where T inverse_eltype(t, T) @@ -327,10 +327,8 @@ function transform_and_logjac(t::VectorTransform, x::AbstractVector) y, ℓ end -# We want to avoid vectors with non-numerical element types -# Ref https://github.com/tpapp/TransformVariables.jl/issues/132 function inverse(t::VectorTransform, y::T) where T - inverse!(Vector{_ensure_float(inverse_eltype(t, T))}(undef, dimension(t)), t, y) + inverse!(Vector{inverse_eltype(t, T)}(undef, dimension(t)), t, y) end """ diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 2ee53080..d8b9ac6a 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -82,7 +82,7 @@ end function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index) (; n) = t - T = robust_eltype(x) + T = _ensure_float(eltype(x)) log_r = zero(T) y = Vector{T}(undef, n) ℓ = logjac_zero(flag, T) @@ -169,7 +169,7 @@ end function transform_with(flag::LogJacFlag, t::UnitVectorNorm, x::AbstractVector, index) (; n, chi_prior) = t - T = robust_eltype(x) + T = _ensure_float(eltype(x)) log_r = zero(T) y = Vector{T}(undef, n) copyto!(y, 1, x, index, n) @@ -227,8 +227,7 @@ dimension(t::UnitSimplex) = t.n - 1 function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, index) (; n) = t - T = robust_eltype(x) - + T = _ensure_float(eltype(x)) ℓ = logjac_zero(flag, T) stick = one(T) y = Vector{T}(undef, n) @@ -370,20 +369,20 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag, U, ℓ, index end -function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVector{T}, - index) where T +function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVector, index) n = result_size(t) - U, ℓ, index′ = calculate_corr_cholesky_factor!(Matrix{robust_eltype(x)}(undef, n, n), - flag, x, index) + T = _ensure_float(eltype(x)) + U, ℓ, index′ = calculate_corr_cholesky_factor!(Matrix{T}(undef, n, n), + flag, x, index) UpperTriangular(U), ℓ, index′ end function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S}, x::AbstractVector{T}, index) where {D,S,T} # NOTE: add an unrolled version for small sizes - E = robust_eltype(x) + E = _ensure_float(eltype(x)) U = if isbitstype(E) - zero(MMatrix{S,S,robust_eltype(x)}) + zero(MMatrix{S,S,E}) else # NOTE: currently allocating because non-bitstype based AD (eg ReverseDiff) does not work with MMatrix zeros(E, S, S) diff --git a/src/utilities.jl b/src/utilities.jl index 2dbb3cff..a825e5b5 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -44,44 +44,29 @@ end """ $(SIGNATURES) -Extend element type of argument so that it is closed under the algebra used by this package. +Regularize scalar (element) types to a floating point, falling back to `Float64`. -Pessimistic default for non-real types. -""" -function robust_eltype(::Type{S}) where S - T = eltype(S) - T <: Real ? typeof(√(one(T))) : Any -end - -robust_eltype(x::T) where T = robust_eltype(T) +It serves two purposes: -""" -$(SIGNATURES) +1. broaden non-float type (eg `Int`) so that they can accommodate the algebraic results, + eg mapping with `log`, -Regularize input type, preferring a floating point, falling back to `Float64`. +2. assign a sensible fallback type (currently `Float64`) for non-numerical element + types; for example, if the input is `Vector{Any}`, `_ensure_float(Any)` will return + `Float64`, -Internal, not exported. - -# Motivation - -Type calculations occasionally give types that are too narrow (eg `Union{}` for empty -vectors) or broad. Since this package is primarily intended for *numerical* -calculations, we fall back to something sensible. This function implements the -heuristics for this, and is currently used in inverse element type calculations. +It is implicitly assumed that the input type is such that it can hold numerical values. +This is typically harmless, since containes for other types (eg `Union{}`, `Nothing`) +will fail anyway. """ -function _ensure_float(::Type{T}) where T - if T <: Number # heuristic: it is assumed that every `Number` type defines `float` - return float(T) - else - return Float64 - end -end +_ensure_float(::Type) = Float64 -# pass through containers -_ensure_float(::Type{T}) where {T<:AbstractArray} = T +# heuristic: it is assumed that every `Number` type defines `float`. +# In case this does not hold, define the relevant method. +_ensure_float(::Type{T}) where {T<:Real} = float(T) # special case Union{} -_ensure_float(::Type{Union{}}) = Float64 +# _ensure_float(::Type{Union{}}) = Float64 """ $(SIGNATURES) diff --git a/test/runtests.jl b/test/runtests.jl index 1c84c106..b5303319 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -900,34 +900,50 @@ end # Empty `inverse(::VectorTransform, _)` for a in (3, 4.7, [5], 3f0, 4.7f0, [5f0]) x = @inferred(inverse(as((; a = Constant(a))), (; a))) - @test x isa Vector{Float64} + @test x isa Vector{Bool} @test isempty(x) x = @inferred(inverse(as((Constant(a),)), (a,))) - @test x isa Vector{Float64} + @test x isa Vector{Bool} @test isempty(x) x = @inferred(inverse(as(Vector, Constant(a), 1), [a])) - @test x isa Vector{Float64} + @test x isa Vector{Bool} @test isempty(x) end # Element type of `inverse(::VectorTransform, _)` - for a in (3, 3.0, 3f0) - T = float(typeof(a)) - - x = @inferred(inverse(as((; a = asℝ)), (; a))) - @test x isa Vector{T} - @test x == [3] + for t in (asℝ, asℝ₊) + for a in (3, 3.0, 3f0) + z = inverse(t, a) + T = typeof(z) + + x = @inferred(inverse(as((; a = t)), (; a))) + @test x isa Vector{T} + @test x == [z] + + x = @inferred(inverse(as((t,)), (a,))) + @test x isa Vector{T} + @test x == [z] + + x = @inferred(inverse(as(Vector, t, 1), [a])) + @test x isa Vector{T} + @test x == [z] + end + end +end - x = @inferred(inverse(as((asℝ,)), (a,))) - @test x isa Vector{T} - @test x == [3] +@testset "nested transformations element type" begin + t = as(Vector, as((a = asℝ,)), 4) + x = zeros(dimension(t)) + y = transform(t, x) + @test @inferred(inverse(t, y)) == x + @test inverse_eltype(t, typeof(y)) ≡ eltype(x) - x = @inferred(inverse(as(Vector, asℝ, 1), [a])) - @test x isa Vector{T} - @test x == [3] - end + t0 = as(Vector, as((a = asℝ,)), 0) + y0 = @inferred(transform(t0, Float64[])) + @test @inferred(transform(t0, Float64[])) == y0 + @test inverse_eltype(t, typeof(y0)) ≡ Float64 end #### From 505c6bc5ad4ff9686c7fcdcf7aa6f6ac71c7832e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Tue, 25 Nov 2025 16:02:47 +0100 Subject: [PATCH 02/10] Refactor scalar inverse_eltype calculations. - get rid of `promote_op` - add test for `Any[...]` input - enable tests which now pass (unrelated) --- src/scalar.jl | 27 ++++++++++++++++++++------- test/runtests.jl | 15 +++++++++------ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index e6ba07e6..6ff07d41 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -30,12 +30,6 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real) index + 1 end -function inverse_eltype(t::ScalarTransform, ::Type{T}) where T <: Real - # NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which - # we test for. If it breaks it should be extended accordingly. - return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), T)) -end - _domain_label(::ScalarTransform, index::Int) = () #### @@ -53,6 +47,8 @@ transform(::Identity, x::Real) = x transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +inverse_eltype(t::Identity, ::Type{T}) where T = T + inverse(::Identity, x::Number) = x inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) @@ -72,6 +68,8 @@ transform(::TVExp, x::Real) = exp(x) transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x +inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T) + function inverse(::TVExp, x::Number) log(x) end @@ -89,6 +87,8 @@ transform(::TVLogistic, x::Real) = logistic(x) transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) +inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T) + function inverse(::TVLogistic, x::Number) logit(x) end @@ -108,6 +108,8 @@ transform(t::TVShift, x::Real) = x + t.shift transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(one(_ensure_float(T)) - one(S)) + inverse(t::TVShift, x::Number) = x - t.shift inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x)) @@ -131,6 +133,8 @@ transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) +inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(one(_ensure_float(T)) / one(S)) + inverse(t::TVScale, x::Number) = x / t.scale inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale) @@ -146,12 +150,14 @@ end transform(::TVNeg, x::Real) = -x transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-one(T)) inverse(::TVNeg, x::Number) = -x inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) #### #### composite scalar transforms #### + """ $(TYPEDEF) @@ -172,7 +178,14 @@ function transform_and_logjac(ts::CompositeScalarTransform, x) end end -inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x) +function inverse_eltype(ts::CompositeScalarTransform, ::Type{T}) where T + foldl((T, t) -> inverse_eltype(t, T), ts.transforms; init = T) +end + +function inverse(ts::CompositeScalarTransform, x) + foldl((y, t) -> inverse(t, y), ts.transforms, init = x) +end + function inverse_and_logjac(ts::CompositeScalarTransform, x) foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t nx, nlogjac = inverse_and_logjac(t, x) diff --git a/test/runtests.jl b/test/runtests.jl index b5303319..ee81ce78 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -462,14 +462,14 @@ end za = as(Array, asℝ₊, 0) @test dimension(zt) == dimension(znt) == 0 @test @inferred(transform(zt, Float64[])) == () - @test_skip inverse(zt, ()) == [] + @test inverse(zt, ()) == [] @test @inferred(transform_and_logjac(zt, Float64[])) == ((), 0.0) @test @inferred(transform(znt, Float64[])) == NamedTuple() @test @inferred(transform_and_logjac(znt, Float64[])) == (NamedTuple(), 0.0) - @test_skip inverse(znt, ()) == [] + @test inverse(znt, (;)) == [] @test @inferred(transform(za, Float64[])) == Float64[] @test @inferred(transform_and_logjac(za, Float64[])) == (Float64[], 0.0) - @test_skip inverse(za, []) == [] + @test inverse(za, []) == [] end @testset "nested combinations" begin @@ -646,7 +646,6 @@ end # end # end - @testset "inference of nested tuples" begin # An MWE adapted from a real-life problem ABOVE1 = as(Real, 1, ∞) # transformation for μ ≥ 1 @@ -741,7 +740,6 @@ end @test inverse(t)(y) == inverse(t, y) == inverse(transform(t))(y) ≈ x end - @testset "ChangesOfVariables" begin t = as(Real, 1.0, 3.0) f = transform(t) @@ -750,7 +748,6 @@ end ChangesOfVariables.test_with_logabsdet_jacobian(inv_f, 1.7, ForwardDiff.derivative) end - @testset "InverseFunctions" begin t = as(Real, 1.0, 3.0) f = transform(t) @@ -946,6 +943,12 @@ end @test inverse_eltype(t, typeof(y0)) ≡ Float64 end +@testset "element type corner cases" begin + t = as(Vector, asℝ₊, 3) + x = @inferred inverse(t, Any[1, 2.0, 3f0]) + @test eltype(x) ≡ Float64 +end + #### #### static analysis with JET #### From 3a652ec40eb5c5fc945ad154cc1e277da4fdbd19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Tue, 25 Nov 2025 17:43:02 +0100 Subject: [PATCH 03/10] tiny coverage improvement --- test/runtests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ee81ce78..3c6f6727 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -827,8 +827,10 @@ end t = as((a = asℝ₊, b = as(Array, asℝ₋, 1, 1), c = corr_cholesky_factor(2), - d = as(SVector{2}, asℝ₊))) - @test [domain_label(t, i) for i in 1:dimension(t)] == [".a", ".b[1,1]", ".c[1]", ".d[1]", ".d[2]"] + d = as(SVector{2}, asℝ₊), + v = as(view, 2))) + @test [domain_label(t, i) for i in 1:dimension(t)] == + [".a", ".b[1,1]", ".c[1]", ".d[1]", ".d[2]", ".v[1]", ".v[2]"] end @testset "static arrays inference" begin From 26778153f4047167512f538ab3129ac7bd27341c Mon Sep 17 00:00:00 2001 From: "Tamas K. Papp" Date: Mon, 1 Dec 2025 10:23:14 +0100 Subject: [PATCH 04/10] Update src/scalar.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Müller-Widmann --- src/scalar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 6ff07d41..081f8cd5 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -108,7 +108,7 @@ transform(t::TVShift, x::Real) = x + t.shift transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) -inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(one(_ensure_float(T)) - one(S)) +inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S)) inverse(t::TVShift, x::Number) = x - t.shift From 6cbee05da5bc91876bb5b11847e6875d314c6a2e Mon Sep 17 00:00:00 2001 From: "Tamas K. Papp" Date: Mon, 1 Dec 2025 10:23:33 +0100 Subject: [PATCH 05/10] Update src/scalar.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Müller-Widmann --- src/scalar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 081f8cd5..45dcbb33 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -133,7 +133,7 @@ transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) -inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(one(_ensure_float(T)) / one(S)) +inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(_ensure_float(T)) / oneunit(S)) inverse(t::TVScale, x::Number) = x / t.scale From 857a51f6c95b144d25dee3a4ce209ffc10a04f5a Mon Sep 17 00:00:00 2001 From: "Tamas K. Papp" Date: Mon, 1 Dec 2025 10:24:41 +0100 Subject: [PATCH 06/10] Update src/scalar.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Müller-Widmann --- src/scalar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 45dcbb33..22e8298e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -150,7 +150,7 @@ end transform(::TVNeg, x::Real) = -x transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) -inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-one(T)) +inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) From e41d5f8a3b82c480713e614dd86bcf0648491834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 1 Dec 2025 10:59:26 +0100 Subject: [PATCH 07/10] remove unused code, clarify how we use `float` --- src/utilities.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/utilities.jl b/src/utilities.jl index a825e5b5..cf25a639 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -61,13 +61,10 @@ will fail anyway. """ _ensure_float(::Type) = Float64 -# heuristic: it is assumed that every `Number` type defines `float`. -# In case this does not hold, define the relevant method. +# heuristic: it is assumed that every `Real` type defines `float`. +# In case this does not hold, the package that defined `T` define `float(::Type{T})`. _ensure_float(::Type{T}) where {T<:Real} = float(T) -# special case Union{} -# _ensure_float(::Type{Union{}}) = Float64 - """ $(SIGNATURES) From 27f41ccc9df7e9cd3122a2380c509b1603187bc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 1 Dec 2025 11:30:54 +0100 Subject: [PATCH 08/10] add another test for Unitful, fix inverse_eltype --- src/scalar.jl | 4 ++-- src/utilities.jl | 6 +++++- test/runtests.jl | 5 +++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 22e8298e..093a36dd 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -25,7 +25,7 @@ function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index:: transform_and_logjac(t, @inbounds x[index])..., index + 1 end -function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real) +function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y) x[index] = inverse(t, y) index + 1 end @@ -133,7 +133,7 @@ transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) -inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(_ensure_float(T)) / oneunit(S)) +inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S)) inverse(t::TVScale, x::Number) = x / t.scale diff --git a/src/utilities.jl b/src/utilities.jl index cf25a639..6462a2e9 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -58,8 +58,12 @@ It serves two purposes: It is implicitly assumed that the input type is such that it can hold numerical values. This is typically harmless, since containes for other types (eg `Union{}`, `Nothing`) will fail anyway. + +!!! NOTE + Call this function *after* stripping units and similar, so that the input is a + subtype of `Real` in most cases. """ -_ensure_float(::Type) = Float64 +_ensure_float(::Type) = Float64 # fallback for Any etc. # heuristic: it is assumed that every `Real` type defines `float`. # In case this does not hold, the package that defined `T` define `float(::Type{T})`. diff --git a/test/runtests.jl b/test/runtests.jl index 3c6f6727..183e61ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -97,6 +97,11 @@ end end end +@testset "vector with TVScale unitful" begin + t = as(Vector, TVScale(2u"m"), 4) + test_transformation(t, y -> y isa Vector && eltype(y) ≡ typeof(2.0u"m"); jac = false) +end + @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg()] for t1 in all_transforms, t2 in all_transforms, t3 in all_transforms From 808e48038725cc9761eb00b566bec02026038a25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 1 Dec 2025 11:31:05 +0100 Subject: [PATCH 09/10] make sure complex numbers do not leak through --- test/runtests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 183e61ee..97fd2458 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -956,6 +956,11 @@ end @test eltype(x) ≡ Float64 end +@testset "inverse error on complex elements" begin + t = as(Vector, asℝ₊, 3) + @test_throws InexactError inverse(t, fill(Complex(0, 1), 3)) +end + #### #### static analysis with JET #### From 527a75cd3827d7b9898272fd062168ab21964164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 1 Dec 2025 11:56:32 +0100 Subject: [PATCH 10/10] tighten inference in tests, remove redundant line --- test/utilities.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/utilities.jl b/test/utilities.jl index 490a7539..bc1ecb20 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -55,11 +55,10 @@ function test_transformation(t::AbstractTransform, is_valid_y; x isa ScalarTransform && @test dimension(x) == 1 y = @inferred transform(t, x) @test is_valid_y(y) - @test transform(t, x) == y if jac y2, lj = @inferred transform_and_logjac(t, x) @test y2 == y - jc = TransformVariables.logprior(t, y) + jc = @inferred TransformVariables.logprior(t, y) if !iszero(jc) @test TransformVariables.nonzero_logprior(t) == true end @@ -70,9 +69,9 @@ function test_transformation(t::AbstractTransform, is_valid_y; end end if test_inverse - x2 = inverse(t, y) + x2 = @inferred inverse(t, y) @test x ≈ x2 atol = 1e-6 - ι = inverse(t) + ι = @inferred inverse(t) @test x2 == ι(y) end end