From 5f4b6c1cab910e7c0222a1ee457a1abc38225b31 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jan 2026 20:40:12 +0330 Subject: [PATCH] test without ode `sol_kwargs` --- benchmark/benchmarks.jl | 10 ---------- test/ci_tests/regression_tests.jl | 5 ----- test/ci_tests/smoke_tests.jl | 6 ------ test/ci_tests/speed_tests.jl | 6 ------ test/quality_tests/checkby_JET_tests.jl | 5 ----- 5 files changed, 32 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 9913dc3e..1c65edde 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -38,11 +38,6 @@ icnf = ContinuousNormalizingFlows.construct( λ₂ = 1.0f-2, λ₃ = 1.0f-2, rng, - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) icnf2 = ContinuousNormalizingFlows.construct( @@ -58,11 +53,6 @@ icnf2 = ContinuousNormalizingFlows.construct( λ₂ = 1.0f-2, λ₃ = 1.0f-2, rng, - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) ps, st = LuxCore.setup(icnf.rng, icnf) diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index 1dbdd176..07f09ada 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -24,11 +24,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test λ₂ = 1.0f-2, λ₃ = 1.0f-2, rng, - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) df = DataFrames.DataFrame(transpose(r), :auto) diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index 0f671f39..26117cab 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -131,11 +131,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be λ₁ = convert(data_type, 1.0e-2), λ₂ = convert(data_type, 1.0e-2), λ₃ = convert(data_type, 1.0e-2), - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) @@ -207,7 +202,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( omode isa ContinuousNormalizingFlows.TrainMode || ( diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 4c081da8..eeb64a96 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -32,7 +32,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in compute_modes - @show compute_mode rng = StableRNGs.StableRNG(1) @@ -60,11 +59,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be λ₂ = 1.0f-2, λ₃ = 1.0f-2, rng, - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) df = DataFrames.DataFrame(transpose(r), :auto) diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index b125208f..72ea31af 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -114,11 +114,6 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg λ₁ = convert(data_type, 1.0e-2), λ₂ = convert(data_type, 1.0e-2), λ₃ = convert(data_type, 1.0e-2), - sol_kwargs = (; - save_everystep = false, - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.GaussAdjoint(), - ), ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps)