From 9be07e13ac2387014f5c4a2925200b26102fddb2 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 6 Sep 2025 17:46:14 +0330 Subject: [PATCH 1/3] test more enzyme --- test/checkby_JET_tests.jl | 7 +++++++ test/smoke_tests.jl | 8 ++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/test/checkby_JET_tests.jl b/test/checkby_JET_tests.jl index 6119889e..021bb0d9 100644 --- a/test/checkby_JET_tests.jl +++ b/test/checkby_JET_tests.jl @@ -116,6 +116,13 @@ Test.@testset "CheckByJET" begin ps = device(ps) st = device(st) + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar && + VERSION >= v"1.11" + continue + end + if cond ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st) JET.test_call( diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 99f6f708..695a80c9 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -200,10 +200,10 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d)) Test.@test !isnothing(rand(d, ndata)) - if GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - continue - end + # if GROUP != "All" && + # compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + # continue + # end Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) From 81a384b8e761a0a33d0c94bfb7e483d15c6874e4 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 6 Sep 2025 19:24:34 +0330 Subject: [PATCH 2/3] mark them broken --- test/smoke_tests.jl | 53 ++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 695a80c9..09bd93fe 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -200,14 +200,13 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d)) Test.@test !isnothing(rand(d, ndata)) - # if GROUP != "All" && - # compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - # continue - # end - Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} if cond model = ContinuousNormalizingFlows.CondICNFModel( @@ -218,14 +217,24 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) + ) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} else model = ContinuousNormalizingFlows.ICNFModel( icnf; @@ -235,12 +244,22 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, df)) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.transform(mach, df)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} end end end From f26d86e756acbeba060493c9b2885077754ae26e Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 6 Sep 2025 20:57:31 +0330 Subject: [PATCH 3/3] fix maybe --- test/smoke_tests.jl | 96 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 09bd93fe..377130b6 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -203,10 +203,22 @@ Test.@testset "Smoke Tests" begin Test.@testset "$adtype on loss" for adtype in adtypes Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) if cond model = ContinuousNormalizingFlows.CondICNFModel( @@ -219,22 +231,52 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fit!(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.serializable(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), ) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) else model = ContinuousNormalizingFlows.ICNFModel( icnf; @@ -246,20 +288,50 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(MLJBase.fit!(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.transform(mach, df)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(MLJBase.serializable(mach)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) end end end