Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ function transform_with(flag::LogJacFlag, transformation::ArrayTransformation, x
len = prod(dims) # number of elements
𝐼 = reshape(range(index; length = len, step = d), dims)
yℓ = map(index -> ((y, ℓ, _) = transform_with(flag, inner_transformation, x, index); (y, ℓ)), 𝐼)
ℓz = logjac_zero(flag, robust_eltype(x))
ℓz = logjac_zero(flag, _ensure_float(eltype(x)))
index′ = index + d * len
first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ), index′
end

function transform_with(flag::LogJacFlag, t::ArrayTransformation{Identity}, x, index)
index′ = index+dimension(t)
y = reshape(x[index:(index′-1)], t.dims)
y, logjac_zero(flag, robust_eltype(x)), index′
y, logjac_zero(flag, _ensure_float(eltype(x))), index′
end

"""
Expand Down Expand Up @@ -133,18 +133,15 @@ dimension(transformation::ViewTransformation) = prod(transformation.dims)
function transform_with(flag::LogJacFlag, t::ViewTransformation, x, index)
index′ = index + dimension(t)
y = reshape(@view(x[index:(index′-1)]), t.dims)
y, logjac_zero(flag, robust_eltype(x)), index′
y, logjac_zero(flag, _ensure_float(eltype(x))), index′
end

function _domain_label(transformation::ViewTransformation, index::Int)
(; dims) = transformation
_array_domain_label(asℝ, dims, index)
end

function inverse_eltype(transformation::ViewTransformation,
::Type{T}) where T <: AbstractArray
_ensure_float(eltype(T))
end
inverse_eltype(transformation::ViewTransformation, ::Type{<:AbstractArray{T}}) where T = T

function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation,
y::AbstractArray)
Expand Down Expand Up @@ -214,8 +211,7 @@ end

function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation},
::Type{T}) where T <: AbstractArray
inverse_eltype(transformation.inner_transformation,
_ensure_float(eltype(T)))
inverse_eltype(transformation.inner_transformation, eltype(T))
end

function inverse_at!(x::AbstractVector, index,
Expand Down Expand Up @@ -376,7 +372,7 @@ Helper function for transforming tuples. Used internally, to help type inference
`transfom_tuple`.
"""
_transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) =
(), logjac_zero(flag, robust_eltype(x)), index
(), logjac_zero(flag, _ensure_float(eltype(x))), index

function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts)
tfirst = first(ts)
Expand Down Expand Up @@ -406,7 +402,7 @@ function _inverse_eltype_tuple(ts::NTransforms{N}, ::Type{T}) where {N,T<:Tuple}
__inverse_eltype_tuple(ts, T)
end
function __inverse_eltype_tuple(ts::NTransforms, ::Type{Tuple{}})
Union{}
Bool
end
function __inverse_eltype_tuple(ts::NTransforms, ::Type{T}) where {T<:Tuple}
promote_type(inverse_eltype(Base.first(ts), fieldtype(T, 1)),
Expand Down
2 changes: 1 addition & 1 deletion src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function transform_with(logjac_flag::LogJacFlag, t::Constant, x::AbstractVector,
t.value, logjac_zero(logjac_flag, eltype(x)), index
end

inverse_eltype(t::Constant, ::Type) = Union{}
inverse_eltype(t::Constant, ::Type) = Bool

function inverse_at!(x::AbstractVector, index, t::Constant, y)
@argcheck t.value == y
Expand Down
14 changes: 6 additions & 8 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,8 @@ with transform, so the following holds:
inverse(t)(y) == inverse(t, y) == inverse(transform(t))(y)
```

!!! note
`eltype(inverse(t, transform(t, x)))` is not necessarily equal to `eltype(x)`,
it is not guaranteed to be the narrowest possible type, and may change without
warning between versions. Some effort is made to come up with a reasonable
concrete type even in corner cases.
Note that `eltype(inverse(t, y)) ≡ inverse_eltype(t, typeof(y))` holds. See
[`inverse_eltype`](@ref).
"""
inverse(t::AbstractTransform) = Base.Fix1(inverse, t)

Expand Down Expand Up @@ -229,6 +226,9 @@ The element type for vector `x` so that `inverse!(x, t, y::T)` works.

3. No dimension or input compatibility checks are guaranteed to be performed, even for
values.

The value is always a numerical type (ie `<:Real`), but in corner cases (eg
transformations of zero dimension) may not be a float.
"""
function inverse_eltype(t::AbstractTransform, y::T) where T
inverse_eltype(t, T)
Expand Down Expand Up @@ -327,10 +327,8 @@ function transform_and_logjac(t::VectorTransform, x::AbstractVector)
y, ℓ
end

# We want to avoid vectors with non-numerical element types
# Ref https://github.com/tpapp/TransformVariables.jl/issues/132
function inverse(t::VectorTransform, y::T) where T
inverse!(Vector{_ensure_float(inverse_eltype(t, T))}(undef, dimension(t)), t, y)
inverse!(Vector{inverse_eltype(t, T)}(undef, dimension(t)), t, y)
end

"""
Expand Down
29 changes: 21 additions & 8 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,11 @@ function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::
transform_and_logjac(t, @inbounds x[index])..., index + 1
end

function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y)
x[index] = inverse(t, y)
index + 1
end

function inverse_eltype(t::ScalarTransform, ::Type{T}) where T <: Real
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
# we test for. If it breaks it should be extended accordingly.
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), T))
end

_domain_label(::ScalarTransform, index::Int) = ()

####
Expand All @@ -53,6 +47,8 @@ transform(::Identity, x::Real) = x

transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::Identity, ::Type{T}) where T = T

inverse(::Identity, x::Number) = x

inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
Expand All @@ -72,6 +68,8 @@ transform(::TVExp, x::Real) = exp(x)

transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x

inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T)

function inverse(::TVExp, x::Number)
log(x)
end
Expand All @@ -89,6 +87,8 @@ transform(::TVLogistic, x::Real) = logistic(x)

transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)

inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T)

function inverse(::TVLogistic, x::Number)
logit(x)
end
Expand All @@ -108,6 +108,8 @@ transform(t::TVShift, x::Real) = x + t.shift

transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S))

inverse(t::TVShift, x::Number) = x - t.shift

inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x))
Expand All @@ -131,6 +133,8 @@ transform(t::TVScale, x::Real) = t.scale * x

transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)

inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S))

inverse(t::TVScale, x::Number) = x / t.scale

inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)
Expand All @@ -146,12 +150,14 @@ end
transform(::TVNeg, x::Real) = -x
transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T))
inverse(::TVNeg, x::Number) = -x
inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x))

####
#### composite scalar transforms
####

"""
$(TYPEDEF)

Expand All @@ -172,7 +178,14 @@ function transform_and_logjac(ts::CompositeScalarTransform, x)
end
end

inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x)
function inverse_eltype(ts::CompositeScalarTransform, ::Type{T}) where T
foldl((T, t) -> inverse_eltype(t, T), ts.transforms; init = T)
end

function inverse(ts::CompositeScalarTransform, x)
foldl((y, t) -> inverse(t, y), ts.transforms, init = x)
end

function inverse_and_logjac(ts::CompositeScalarTransform, x)
foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t
nx, nlogjac = inverse_and_logjac(t, x)
Expand Down
19 changes: 9 additions & 10 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end

function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
(; n) = t
T = robust_eltype(x)
T = _ensure_float(eltype(x))
log_r = zero(T)
y = Vector{T}(undef, n)
ℓ = logjac_zero(flag, T)
Expand Down Expand Up @@ -169,7 +169,7 @@ end

function transform_with(flag::LogJacFlag, t::UnitVectorNorm, x::AbstractVector, index)
(; n, chi_prior) = t
T = robust_eltype(x)
T = _ensure_float(eltype(x))
log_r = zero(T)
y = Vector{T}(undef, n)
copyto!(y, 1, x, index, n)
Expand Down Expand Up @@ -227,8 +227,7 @@ dimension(t::UnitSimplex) = t.n - 1

function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, index)
(; n) = t
T = robust_eltype(x)

T = _ensure_float(eltype(x))
ℓ = logjac_zero(flag, T)
stick = one(T)
y = Vector{T}(undef, n)
Expand Down Expand Up @@ -370,20 +369,20 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag,
U, ℓ, index
end

function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVector{T},
index) where T
function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVector, index)
n = result_size(t)
U, ℓ, index′ = calculate_corr_cholesky_factor!(Matrix{robust_eltype(x)}(undef, n, n),
flag, x, index)
T = _ensure_float(eltype(x))
U, ℓ, index′ = calculate_corr_cholesky_factor!(Matrix{T}(undef, n, n),
flag, x, index)
UpperTriangular(U), ℓ, index′
end

function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S},
x::AbstractVector{T}, index) where {D,S,T}
# NOTE: add an unrolled version for small sizes
E = robust_eltype(x)
E = _ensure_float(eltype(x))
U = if isbitstype(E)
zero(MMatrix{S,S,robust_eltype(x)})
zero(MMatrix{S,S,E})
else
# NOTE: currently allocating because non-bitstype based AD (eg ReverseDiff) does not work with MMatrix
zeros(E, S, S)
Expand Down
48 changes: 17 additions & 31 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,30 @@ end
"""
$(SIGNATURES)

Extend element type of argument so that it is closed under the algebra used by this package.
Regularize scalar (element) types to a floating point, falling back to `Float64`.

Pessimistic default for non-real types.
"""
function robust_eltype(::Type{S}) where S
T = eltype(S)
T <: Real ? typeof(√(one(T))) : Any
end

robust_eltype(x::T) where T = robust_eltype(T)
It serves two purposes:

"""
$(SIGNATURES)
1. broaden non-float type (eg `Int`) so that they can accommodate the algebraic results,
eg mapping with `log`,

Regularize input type, preferring a floating point, falling back to `Float64`.
2. assign a sensible fallback type (currently `Float64`) for non-numerical element
types; for example, if the input is `Vector{Any}`, `_ensure_float(Any)` will return
`Float64`,

Internal, not exported.
It is implicitly assumed that the input type is such that it can hold numerical values.
This is typically harmless, since containes for other types (eg `Union{}`, `Nothing`)
will fail anyway.

# Motivation

Type calculations occasionally give types that are too narrow (eg `Union{}` for empty
vectors) or broad. Since this package is primarily intended for *numerical*
calculations, we fall back to something sensible. This function implements the
heuristics for this, and is currently used in inverse element type calculations.
!!! NOTE
Call this function *after* stripping units and similar, so that the input is a
subtype of `Real` in most cases.
"""
function _ensure_float(::Type{T}) where T
if T <: Number # heuristic: it is assumed that every `Number` type defines `float`
return float(T)
else
return Float64
end
end

# pass through containers
_ensure_float(::Type{T}) where {T<:AbstractArray} = T
_ensure_float(::Type) = Float64 # fallback for Any etc.

# special case Union{}
_ensure_float(::Type{Union{}}) = Float64
# heuristic: it is assumed that every `Real` type defines `float`.
# In case this does not hold, the package that defined `T` define `float(::Type{T})`.
_ensure_float(::Type{T}) where {T<:Real} = float(T)

"""
$(SIGNATURES)
Expand Down
Loading
Loading