diff --git a/src/scalar.jl b/src/scalar.jl index 093a36d..bd311da 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -17,11 +17,11 @@ abstract type ScalarTransform <: AbstractTransform end dimension(::ScalarTransform) = 1 -function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int) +function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index) transform(t, @inbounds x[index]), flag, index + 1 end -function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int) +function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index) transform_and_logjac(t, @inbounds x[index])..., index + 1 end @@ -43,15 +43,15 @@ Identity ``x ↦ x``. """ struct Identity <: ScalarTransform end -transform(::Identity, x::Real) = x +transform(::Identity, x::Number) = x -transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::Identity, ::Type{T}) where T = T inverse(::Identity, x::Number) = x -inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) #### #### elementary scalar transforms @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real """ struct TVExp <: ScalarTransform end -transform(::TVExp, x::Real) = exp(x) +transform(::TVExp, x::Number) = exp(x) -transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x +transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T) @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). """ struct TVLogistic <: ScalarTransform end -transform(::TVLogistic, x::Real) = logistic(x) +transform(::TVLogistic, x::Number) = logistic(x) -transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) +transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x) inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T) @@ -100,13 +100,13 @@ $(TYPEDEF) Shift transformation `x ↦ x + shift`. """ -struct TVShift{T <: Real} <: ScalarTransform +struct TVShift{T} <: ScalarTransform shift::T end -transform(t::TVShift, x::Real) = x + t.shift +transform(t::TVShift, x::Number) = x + t.shift -transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S)) @@ -129,15 +129,15 @@ end TVScale(scale::T) where {T} = TVScale{T}(scale) -transform(t::TVScale, x::Real) = t.scale * x +transform(t::TVScale, x::Number) = t.scale * x -transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) +transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale) inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S)) inverse(t::TVScale, x::Number) = x / t.scale -inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale) +inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`. struct TVNeg <: ScalarTransform end -transform(::TVNeg, x::Real) = -x -transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform(::TVNeg, x::Number) = -x +transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x diff --git a/src/utilities.jl b/src/utilities.jl index 6462a2e..06a24e8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -2,7 +2,7 @@ ### logistic and logit ### -function logistic_logjac(x::Real) +function logistic_logjac(x::Number) mx = -abs(x) mx - 2*log1pexp(mx) end