From 7f684520b86123f817a9e3f8a330af4b914c9f6e Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 3 Sep 2025 20:04:59 +0330 Subject: [PATCH 01/13] mark forwarddiff tests broken --- test/smoke_tests.jl | 6 ++++-- test/speed_tests.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 2c97000e..4ce0ab55 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -194,8 +194,10 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d, ndata)) Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = + compute_mode.adback isa ADTypes.AutoForwardDiff + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = + compute_mode.adback isa ADTypes.AutoForwardDiff Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 01052781..7b90e75d 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -57,7 +57,8 @@ Test.@testset "Speed Tests" begin model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + compute_mode.adback isa ADTypes.AutoForwardDiff @show only(MLJBase.report(mach).stats).time From 78c3cf56ffe76a5e48a813e09c52afdcbc294a49 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 3 Sep 2025 20:08:49 +0330 Subject: [PATCH 02/13] only if its not all test --- test/smoke_tests.jl | 4 ++-- test/speed_tests.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 4ce0ab55..6aafe953 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -195,9 +195,9 @@ Test.@testset "Smoke Tests" begin Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = - compute_mode.adback isa ADTypes.AutoForwardDiff + GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = - compute_mode.adback isa ADTypes.AutoForwardDiff + GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 7b90e75d..7b33bb95 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -58,7 +58,7 @@ Test.@testset "Speed Tests" begin mach = MLJBase.machine(model, df) Test.@test !isnothing(MLJBase.fit!(mach)) broken = - compute_mode.adback isa ADTypes.AutoForwardDiff + GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff @show only(MLJBase.report(mach).stats).time From 9feec7e07070ff768194d7046157c568b99afd46 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 3 Sep 2025 21:32:44 +0330 Subject: [PATCH 03/13] maybe a fix --- src/utils.jl | 2 ++ test/smoke_tests.jl | 6 ++---- test/speed_tests.jl | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8579efe2..3f2b712c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,6 +5,7 @@ function jacobian_batched( ) where {T} y = f(xs) z = similar(xs) + z = convert.(promote_type(eltype(z), eltype(f.ps)), z) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) @@ -23,6 +24,7 @@ function jacobian_batched( ) where {T} y = f(xs) z = similar(xs) + z = convert.(promote_type(eltype(z), eltype(f.ps)), z) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 6aafe953..2c97000e 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -194,10 +194,8 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d, ndata)) Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = - GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = - GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 7b33bb95..01052781 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -57,8 +57,7 @@ Test.@testset "Speed Tests" begin model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - GROUP != "All" && compute_mode.adback isa ADTypes.AutoForwardDiff + Test.@test !isnothing(MLJBase.fit!(mach)) @show only(MLJBase.report(mach).stats).time From b3bb9af7a15893bd1db82fe8c372621067f6b7a3 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 3 Sep 2025 22:33:23 +0330 Subject: [PATCH 04/13] fix ? --- src/utils.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 3f2b712c..df9780dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,9 +5,13 @@ function jacobian_batched( ) where {T} y = f(xs) z = similar(xs) - z = convert.(promote_type(eltype(z), eltype(f.ps)), z) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + res = Zygote.Buffer( + convert.(promote_type(eltype(xs), eltype(f.ps)), xs), + size(xs, 1), + size(xs, 1), + size(xs, 2), + ) for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) res[i, :, :] = @@ -24,9 +28,13 @@ function jacobian_batched( ) where {T} y = f(xs) z = similar(xs) - z = convert.(promote_type(eltype(z), eltype(f.ps)), z) ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + res = Zygote.Buffer( + convert.(promote_type(eltype(xs), eltype(f.ps)), xs), + size(xs, 1), + size(xs, 1), + size(xs, 2), + ) for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) res[:, i, :] = only( From 5dbc35beb9bf4aedd345ebad1ea096a0bc8e683a Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 3 Sep 2025 23:59:22 +0330 Subject: [PATCH 05/13] mark forward enzyme --- test/smoke_tests.jl | 14 +++++++++----- test/speed_tests.jl | 8 +++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 2c97000e..2a065c4d 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -36,6 +36,9 @@ Test.@testset "Smoke Tests" begin ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ContinuousNormalizingFlows.DIVecJacVectorMode( ADTypes.AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), @@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ] Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in @@ -194,8 +194,12 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d, ndata)) Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 01052781..411c096f 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -2,6 +2,8 @@ Test.@testset "Speed Tests" begin compute_modes = ContinuousNormalizingFlows.ComputeMode[ ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), ContinuousNormalizingFlows.DIVecJacMatrixMode( ADTypes.AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), @@ -14,8 +16,6 @@ Test.@testset "Speed Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()), - ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()), ] Test.@testset "$compute_mode" for compute_mode in compute_modes @@ -57,7 +57,9 @@ Test.@testset "Speed Tests" begin model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} @show only(MLJBase.report(mach).stats).time From a709393865d93a5613ff90247c0b3e31e3bfc048 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 01:09:30 +0330 Subject: [PATCH 06/13] more broken --- test/smoke_tests.jl | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 2a065c4d..f120039c 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -212,14 +212,22 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) - Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@test !isnothing(MLJBase.serializable(mach)) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) + ) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} else model = ContinuousNormalizingFlows.ICNFModel( icnf; @@ -230,12 +238,20 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, df)) - Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.transform(mach, df)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@test !isnothing(MLJBase.serializable(mach)) - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} end end end From af1f9c35db52450c6aff22c2c1d62b268241232f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 10:06:33 +0330 Subject: [PATCH 07/13] fix --- test/smoke_tests.jl | 8 ++++++-- test/speed_tests.jl | 16 +++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index f120039c..214a1619 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -221,7 +221,9 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = GROUP != "All" && compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), @@ -247,7 +249,9 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = GROUP != "All" && compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = GROUP != "All" && diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 411c096f..77cb080f 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -10,12 +10,12 @@ Test.@testset "Speed Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Forward), - function_annotation = Enzyme.Const, - ), - ), + # ContinuousNormalizingFlows.DIJacVecMatrixMode( + # ADTypes.AutoEnzyme(; + # mode = Enzyme.set_runtime_activity(Enzyme.Forward), + # function_annotation = Enzyme.Const, + # ), + # ), ] Test.@testset "$compute_mode" for compute_mode in compute_modes @@ -57,9 +57,7 @@ Test.@testset "Speed Tests" begin model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + MLJBase.fit!(mach) @show only(MLJBase.report(mach).stats).time From 197c5b7fa912d9c4a7f807ae19258efa9064c041 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 10:08:37 +0330 Subject: [PATCH 08/13] clean --- test/smoke_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 214a1619..71f2646b 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -45,15 +45,15 @@ Test.@testset "Smoke Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecVectorMode( + ContinuousNormalizingFlows.DIVecJacMatrixMode( ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Forward), + mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIVecJacMatrixMode( + ContinuousNormalizingFlows.DIJacVecVectorMode( ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Reverse), + mode = Enzyme.set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const, ), ), From 2e49d5260cf69ec960a16142182db69595678e18 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 12:43:24 +0330 Subject: [PATCH 09/13] no planar --- test/smoke_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 71f2646b..9204bac1 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -15,7 +15,7 @@ Test.@testset "Smoke Tests" begin else Bool[false, true], Bool[false, true] end - planars = Bool[false, true] + planars = Bool[false] nvars_ = Int[2] ndata_ = Int[4] n_epochs_ = Int[2] From 8ec6d0f8d96f1ea48ea12b231badc475597ae060 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 12:55:16 +0330 Subject: [PATCH 10/13] skip it --- test/smoke_tests.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 9204bac1..71a6b10c 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -15,7 +15,7 @@ Test.@testset "Smoke Tests" begin else Bool[false, true], Bool[false, true] end - planars = Bool[false] + planars = Bool[false, true] nvars_ = Int[2] ndata_ = Int[4] n_epochs_ = Int[2] @@ -196,10 +196,17 @@ Test.@testset "Smoke Tests" begin Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond @@ -214,7 +221,10 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fit!(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = GROUP != "All" && compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} @@ -242,7 +252,10 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fit!(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar Test.@test !isnothing(MLJBase.transform(mach, df)) broken = GROUP != "All" && compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} From 7338a1719a02d23d3642d44fc6b4e9d82bd48a68 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 13:50:07 +0330 Subject: [PATCH 11/13] comment it --- test/smoke_tests.jl | 85 +++++++++++++-------------------------------- 1 file changed, 24 insertions(+), 61 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 71a6b10c..d9a59504 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -51,18 +51,18 @@ Test.@testset "Smoke Tests" begin function_annotation = Enzyme.Const, ), ), - ContinuousNormalizingFlows.DIJacVecVectorMode( - ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Forward), - function_annotation = Enzyme.Const, - ), - ), - ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; - mode = Enzyme.set_runtime_activity(Enzyme.Forward), - function_annotation = Enzyme.Const, - ), - ), + # ContinuousNormalizingFlows.DIJacVecVectorMode( + # ADTypes.AutoEnzyme(; + # mode = Enzyme.set_runtime_activity(Enzyme.Forward), + # function_annotation = Enzyme.Const, + # ), + # ), + # ContinuousNormalizingFlows.DIJacVecMatrixMode( + # ADTypes.AutoEnzyme(; + # mode = Enzyme.set_runtime_activity(Enzyme.Forward), + # function_annotation = Enzyme.Const, + # ), + # ), ] Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in @@ -194,19 +194,8 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d, ndata)) Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && - planar - - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && - planar + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_ if cond @@ -219,27 +208,14 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && - planar - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + ) else model = ContinuousNormalizingFlows.ICNFModel( icnf; @@ -250,25 +226,12 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} skip = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && - planar - Test.@test !isnothing(MLJBase.transform(mach, df)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, df)) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = - GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) end end end From aa836c3c3cc61d7bc863e71a79683aa2b6b8e749 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 13:56:26 +0330 Subject: [PATCH 12/13] skip them --- test/smoke_tests.jl | 31 ++++++++++++++++++------------- test/speed_tests.jl | 18 ++++++++++++------ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index d9a59504..7bb4f450 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -22,6 +22,7 @@ Test.@testset "Smoke Tests" begin data_types = Type{<:AbstractFloat}[Float32] devices = MLDataDevices.AbstractDevice[MLDataDevices.cpu_device()] adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(), + # ADTypes.AutoForwardDiff(), # ADTypes.AutoEnzyme(; # mode = Enzyme.set_runtime_activity(Enzyme.Reverse), # function_annotation = Enzyme.Const, @@ -30,7 +31,6 @@ Test.@testset "Smoke Tests" begin # mode = Enzyme.set_runtime_activity(Enzyme.Forward), # function_annotation = Enzyme.Const, # ), - # ADTypes.AutoForwardDiff(), ] compute_modes = ContinuousNormalizingFlows.ComputeMode[ ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), @@ -51,18 +51,18 @@ Test.@testset "Smoke Tests" begin function_annotation = Enzyme.Const, ), ), - # ContinuousNormalizingFlows.DIJacVecVectorMode( - # ADTypes.AutoEnzyme(; - # mode = Enzyme.set_runtime_activity(Enzyme.Forward), - # function_annotation = Enzyme.Const, - # ), - # ), - # ContinuousNormalizingFlows.DIJacVecMatrixMode( - # ADTypes.AutoEnzyme(; - # mode = Enzyme.set_runtime_activity(Enzyme.Forward), - # function_annotation = Enzyme.Const, - # ), - # ), + ContinuousNormalizingFlows.DIJacVecVectorMode( + ADTypes.AutoEnzyme(; + mode = Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const, + ), + ), + ContinuousNormalizingFlows.DIJacVecMatrixMode( + ADTypes.AutoEnzyme(; + mode = Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const, + ), + ), ] Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in @@ -193,6 +193,11 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d)) Test.@test !isnothing(rand(d, ndata)) + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + continue + end + Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) diff --git a/test/speed_tests.jl b/test/speed_tests.jl index 77cb080f..8e173b7c 100644 --- a/test/speed_tests.jl +++ b/test/speed_tests.jl @@ -10,12 +10,12 @@ Test.@testset "Speed Tests" begin function_annotation = Enzyme.Const, ), ), - # ContinuousNormalizingFlows.DIJacVecMatrixMode( - # ADTypes.AutoEnzyme(; - # mode = Enzyme.set_runtime_activity(Enzyme.Forward), - # function_annotation = Enzyme.Const, - # ), - # ), + ContinuousNormalizingFlows.DIJacVecMatrixMode( + ADTypes.AutoEnzyme(; + mode = Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation = Enzyme.Const, + ), + ), ] Test.@testset "$compute_mode" for compute_mode in compute_modes @@ -54,6 +54,12 @@ Test.@testset "Speed Tests" begin ) df = DataFrames.DataFrame(transpose(r), :auto) + + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + continue + end + model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5) mach = MLJBase.machine(model, df) From 4eb943d5131d8b9ea97a0d9cf96ae711dd49910f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 4 Sep 2025 15:51:29 +0330 Subject: [PATCH 13/13] collect eachcol --- src/exts/dist_ext/core_cond_icnf.jl | 8 ++++++-- src/exts/dist_ext/core_icnf.jl | 8 ++++++-- src/exts/mlj_ext/core_cond_icnf.jl | 5 +++-- src/exts/mlj_ext/core_icnf.jl | 8 ++++++-- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index f0ee3841..ca12e213 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -19,6 +19,7 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." first(Distributions._logpdf(d, hcat(x))) else error("Not Implemented") @@ -26,7 +27,8 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) end function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) + @warn "to compute by vectors, data should be a vector." + Distributions._logpdf.(d, collect(collect.(eachcol(A)))) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) else @@ -41,6 +43,7 @@ function Distributions._rand!( return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} x .= generate(d.m, d.mode, d.ys, d.ps, d.st) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." x .= Distributions._rand!(rng, d, hcat(x)) else error("Not Implemented") @@ -52,7 +55,8 @@ function Distributions._rand!( A::AbstractMatrix{<:Real}, ) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) + @warn "to compute by vectors, data should be a vector." + A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) else diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index b22b9dc2..2c937229 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -14,6 +14,7 @@ function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} first(inference(d.m, d.mode, x, d.ps, d.st)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." first(Distributions._logpdf(d, hcat(x))) else error("Not Implemented") @@ -22,7 +23,8 @@ end function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) + @warn "to compute by vectors, data should be a vector." + Distributions._logpdf.(d, collect(collect.(eachcol(A)))) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} first(inference(d.m, d.mode, A, d.ps, d.st)) else @@ -38,6 +40,7 @@ function Distributions._rand!( return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} x .= generate(d.m, d.mode, d.ps, d.st) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + @warn "to compute by matrices, data should be a matrix." x .= Distributions._rand!(rng, d, hcat(x)) else error("Not Implemented") @@ -49,7 +52,8 @@ function Distributions._rand!( A::AbstractMatrix{<:Real}, ) return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) + @warn "to compute by vectors, data should be a vector." + A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) else diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 2aaf38a8..07f85601 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -80,12 +80,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) (ps, st) = fitresult logp̂x = if model.m.compute_mode isa VectorMode + @warn "to compute by vectors, data should be a vector." broadcast( function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) return first(inference(model.m, TestMode(), x, y, ps, st)) end, - eachcol(xnew), - eachcol(ynew), + collect(collect.(eachcol(xnew))), + collect(collect.(eachcol(ynew))), ) elseif model.m.compute_mode isa MatrixMode first(inference(model.m, TestMode(), xnew, ynew, ps, st)) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 57c5ff46..ba9c4a86 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -74,9 +74,13 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) (ps, st) = fitresult logp̂x = if model.m.compute_mode isa VectorMode - broadcast(function (x::AbstractVector{<:Real}) + @warn "to compute by vectors, data should be a vector." + broadcast( + function (x::AbstractVector{<:Real}) return first(inference(model.m, TestMode(), x, ps, st)) - end, eachcol(xnew)) + end, + collect(collect.(eachcol(xnew))), + ) elseif model.m.compute_mode isa MatrixMode first(inference(model.m, TestMode(), xnew, ps, st)) else