diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index f0ee3841..ca12e213 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -19,6 +19,7 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." first(Distributions._logpdf(d, hcat(x))) else error("Not Implemented") @@ -26,7 +27,8 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) end function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) + @warn "to compute by vectors, data should be a vector." + Distributions._logpdf.(d, collect(collect.(eachcol(A)))) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) else @@ -41,6 +43,7 @@ function Distributions._rand!( return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} x .= generate(d.m, d.mode, d.ys, d.ps, d.st) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." x .= Distributions._rand!(rng, d, hcat(x)) else error("Not Implemented") @@ -52,7 +55,8 @@ function Distributions._rand!( A::AbstractMatrix{<:Real}, ) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) + @warn "to compute by vectors, data should be a vector." + A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) else diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index b22b9dc2..2c937229 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -14,6 +14,7 @@ function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} first(inference(d.m, d.mode, x, d.ps, d.st)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." first(Distributions._logpdf(d, hcat(x))) else error("Not Implemented") @@ -22,7 +23,8 @@ end function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) + @warn "to compute by vectors, data should be a vector." + Distributions._logpdf.(d, collect(collect.(eachcol(A)))) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} first(inference(d.m, d.mode, A, d.ps, d.st)) else @@ -38,6 +40,7 @@ function Distributions._rand!( return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} x .= generate(d.m, d.mode, d.ps, d.st) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." x .= Distributions._rand!(rng, d, hcat(x)) else error("Not Implemented") @@ -49,7 +52,8 @@ function Distributions._rand!( A::AbstractMatrix{<:Real}, ) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) + @warn "to compute by vectors, data should be a vector." + A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) else diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 2aaf38a8..07f85601 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -80,12 +80,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) (ps, st) = fitresult logp̂x = if model.m.compute_mode isa VectorMode + @warn "to compute by vectors, data should be a vector." broadcast( function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) return first(inference(model.m, TestMode(), x, y, ps, st)) end, - eachcol(xnew), - eachcol(ynew), + collect(collect.(eachcol(xnew))), + collect(collect.(eachcol(ynew))), ) elseif model.m.compute_mode isa MatrixMode first(inference(model.m, TestMode(), xnew, ynew, ps, st)) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 57c5ff46..ba9c4a86 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -74,9 +74,13 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) (ps, st) = fitresult logp̂x = if model.m.compute_mode isa VectorMode - broadcast(function (x::AbstractVector{<:Real}) + @warn "to compute by vectors, data should be a vector." + broadcast( + function (x::AbstractVector{<:Real}) return first(inference(model.m, TestMode(), x, ps, st)) - end, eachcol(xnew)) + end, + collect(collect.(eachcol(xnew))), + ) elseif model.m.compute_mode isa MatrixMode first(inference(model.m, TestMode(), xnew, ps, st)) else diff --git a/src/utils.jl b/src/utils.jl index 8579efe2..df9780dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,7 +6,12 @@ function jacobian_batched( y = f(xs) z = similar(xs) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + res = Zygote.Buffer( + convert.(promote_type(eltype(xs), eltype(f.ps)), xs), + size(xs, 1), + size(xs, 1), + size(xs, 2), + ) for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) res[i, :, :] = @@ -24,7 +29,12 @@ function jacobian_batched( y = f(xs) z = similar(xs) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + res = Zygote.Buffer( + convert.(promote_type(eltype(xs), eltype(f.ps)), xs), + size(xs, 1), + size(xs, 1), + size(xs, 2), + ) for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) res[:, i, :] = only( diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 2c97000e..7bb4f450 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -22,6 +22,7 @@ Test.@testset "Smoke Tests" begin data_types = Type{<:AbstractFloat}[Float32] devices = MLDataDevices.AbstractDevice[MLDataDevices.cpu_device()] adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(), + # ADTypes.AutoForwardDiff(), # ADTypes.AutoEnzyme(; # mode = Enzyme.set_runtime_activity(Enzyme.Reverse), # function_annotation = Enzyme.Const, @@ -30,27 +31,29 @@ Test.@testset "Smoke Tests" begin # mode = Enzyme.set_runtime_activity(Enzyme.Forward), # function_annotation = Enzyme.Const, # ), - # ADTypes.AutoForwardDiff(), ] compute_modes = ContinuousNormalizingFlows.ComputeMode[ ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ContinuousNormalizingFlows.DIVecJacVectorMode( ADTypes.AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecVectorMode( + ContinuousNormalizingFlows.DIVecJacMatrixMode( ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Forward), + mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIVecJacMatrixMode( + ContinuousNormalizingFlows.DIJacVecVectorMode( ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Reverse), + mode = Enzyme.set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const, ), ), @@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ] Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in @@ -193,6 +193,11 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d)) Test.@test !isnothing(rand(d, ndata)) + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + continue + end + Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 01052781..8e173b7c 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -2,6 +2,8 @@ Test.@testset "Speed Tests" begin compute_modes = ContinuousNormalizingFlows.ComputeMode[ ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ContinuousNormalizingFlows.DIVecJacMatrixMode( ADTypes.AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), @@ -14,8 +16,6 @@ Test.@testset "Speed Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), ] Test.@testset "$compute_mode" for compute_mode in compute_modes @@ -54,10 +54,16 @@ Test.@testset "Speed Tests" begin ) df = DataFrames.DataFrame(transpose(r), :auto) + + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + continue + end + model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) + MLJBase.fit!(mach) @show only(MLJBase.report(mach).stats).time