From ea589ab8ef461ff806dd6916de895230fb3d35d0 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 1 Jan 2024 00:39:45 +0100 Subject: [PATCH 01/25] wip --- Project.toml | 5 +++++ ext/ProximalAlgorithmsZygoteExt.jl | 18 ++++++++++++++++++ src/ProximalAlgorithms.jl | 7 +++++-- src/algorithms/davis_yin.jl | 6 ++++-- src/algorithms/fast_forward_backward.jl | 6 ++++-- src/algorithms/forward_backward.jl | 6 ++++-- src/algorithms/li_lin.jl | 9 ++++++--- src/algorithms/panoc.jl | 12 ++++++++---- src/algorithms/panocplus.jl | 12 ++++++++---- src/algorithms/primal_dual.jl | 6 ++++-- src/algorithms/sfista.jl | 6 ++++-- src/algorithms/zerofpr.jl | 10 +++++++--- src/utilities/ad.jl | 14 -------------- src/utilities/fb_tools.jl | 22 ++++++++++++++-------- test/Project.toml | 1 + test/utilities/test_ad.jl | 1 + 16 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 ext/ProximalAlgorithmsZygoteExt.jl delete mode 100644 src/utilities/ad.jl diff --git a/Project.toml b/Project.toml index ab0bfb2..5963e2b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,8 +6,13 @@ version = "0.5.5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" + +[weakdeps] Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[extensions] +ProximalAlgorithmsZygoteExt = "Zygote" + [compat] LinearAlgebra = "1.2" Printf = "1.2" diff --git a/ext/ProximalAlgorithmsZygoteExt.jl b/ext/ProximalAlgorithmsZygoteExt.jl new file mode 100644 index 0000000..c5f466c --- /dev/null +++ b/ext/ProximalAlgorithmsZygoteExt.jl @@ -0,0 +1,18 @@ +module ProximalAlgorithmsZygoteExt + +using ProximalAlgorithms +using Zygote: pullback + +struct ZygoteFunction{F} + f::F +end + +(f::ZygoteFunction)(x) = f.f(x) + +function ProximalAlgorithms.eval_with_pullback(f::ZygoteFunction, x) + out, pb = pullback(f, x) + zygote_pullback() = pb(one(out))[1] + return out, zygote_pullback +end + +end diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 743c0ea..0bd51cc 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,14 +1,17 @@ module ProximalAlgorithms using ProximalCore -using ProximalCore: prox, prox!, gradient, gradient! +using ProximalCore: prox, prox! 
const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} +# gradient computation + +eval_with_pullback(f::F, x::X) where {F, X} = error("undefined `eval_with_pullback` for function type $F on argument of type $X") + # various utilities -include("utilities/ad.jl") include("utilities/fb_tools.jl") include("utilities/iteration_tools.jl") diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index 93ad084..baf7be8 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -55,7 +55,8 @@ end function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) - grad_f_xg, = gradient(iter.f, xg) + _, pb = eval_with_pullback(iter.f, xg) + grad_f_xg = pb() z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) res = xh - xg @@ -66,7 +67,8 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) - gradient!(state.grad_f_xg, iter.f, state.xg) + _, pb = eval_with_pullback(iter.f, state.xg) + state.grad_f_xg .= pb() state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) state.res .= state.xh .- state.xg diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index b797478..bd96a21 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -68,7 +68,8 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) - grad_f_x, f_x = gradient(iter.f, x) + f_x, pb = eval_with_pullback(iter.f, x) + grad_f_x = pb() gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -103,7 +104,8 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.x .= state.z .+ beta .* (state.z .- state.z_prev) state.z_prev, state.z = state.z, state.z_prev - state.f_x = gradient!(state.grad_f_x, iter.f, state.x) + state.f_x, pb = eval_with_pullback(iter.f, state.x) + state.grad_f_x .= pb() state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) state.res .= state.x .- state.z diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index 6d7f43b..e651b58 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -59,7 +59,8 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) - grad_f_x, f_x = gradient(iter.f, x) + f_x, pb = eval_with_pullback(iter.f, x) + grad_f_x = pb() gamma = iter.gamma === nothing ? 
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -81,7 +82,8 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x else state.x, state.z = state.z, state.x - state.f_x = gradient!(state.grad_f_x, iter.f, state.x) + state.f_x, pb = eval_with_pullback(iter.f, state.x) + state.grad_f_x .= pb() end state.y .= state.x .- state.gamma .* state.grad_f_x diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index cb572a8..3be37c7 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -62,7 +62,8 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) - grad_f_y, f_y = gradient(iter.f, y) + f_y, pb = eval_with_pullback(iter.f, y) + grad_f_y = pb() # TODO: initialize gamma if not provided # TODO: authors suggest Barzilai-Borwein rule? @@ -102,7 +103,8 @@ function Base.iterate( else # TODO: re-use available space in state? # TODO: backtrack gamma at x - grad_f_x, f_x = gradient(iter.f, x) + f_x, pb = eval_with_pullback(iter.f, x) + grad_f_x = pb() x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) Fv = iter.f(v) + g_v @@ -121,7 +123,8 @@ function Base.iterate( Fx = Fv end - state.f_y = gradient!(state.grad_f_y, iter.f, state.y) + state.f_y, pb = eval_with_pullback(iter.f, state.y) + state.grad_f_y .= pb() state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index 5f38dcb..78eb7ff 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -86,7 +86,8 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - grad_f_Ax, f_Ax = gradient(iter.f, Ax) + f_Ax, pb = eval_with_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -152,7 +153,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad - state.f_Ax_d = gradient!(state.grad_f_Ax_d, iter.f, state.Ax_d) + state.f_Ax_d, pb = eval_with_pullback(iter.f, state.Ax_d) + state.grad_f_Ax_d .= pb() mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) copyto!(state.x, state.x_d) @@ -189,7 +191,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # along a line using interpolation and linear combinations # this allows saving operations if isinf(f_Az) - f_Az = gradient!(state.grad_f_Az, iter.f, state.Az) + f_Az, pb = eval_with_pullback(iter.f, state.Az) + state.grad_f_Az .= pb() end if isinf(c) mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az) @@ -203,7 +206,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where else # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step - state.f_Ax = gradient!(state.grad_f_Ax, iter.f, state.Ax) + state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index cc9e994..5d113c4 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -79,7 +79,8 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = f_model(state.f_Ax, s function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - grad_f_Ax, f_Ax = gradient(iter.f, Ax) + f_Ax, pb = eval_with_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -97,7 +98,8 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} ) else mul!(state.Az, iter.A, state.z) - gradient!(state.grad_f_Az, iter.f, state.Az) + _, pb = eval_with_pullback(iter.f, state.Az) + state.grad_f_Az = pb() end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) return state, state @@ -152,7 +154,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where end mul!(state.Ax, iter.A, state.x) - state.f_Ax = gradient!(state.grad_f_Ax, iter.f, state.Ax) + state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax @@ -162,7 +165,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where f_Az_upp = f_model(iter, state) mul!(state.Az, iter.A, state.z) - f_Az = gradient!(state.grad_f_Az, iter.f, state.Az) + f_Az, pb = eval_with_pullback(iter.f, state.Az) + state.grad_f_Az .= pb() if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index 1ac83ba..9a4da4e 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -167,7 +167,8 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step - gradient!(state.gradf, iter.f, state.x) + _, pb = eval_with_pullback(iter.f, state.x) + state.gradf .= pb() mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf state.temp_x .*= -iter.gamma[1] @@ -175,7 +176,8 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - gradient!(state.gradl, convex_conjugate(iter.l), state.y) + _, pb = eval_with_pullback(convex_conjugate(iter.l), state.y) + state.gradl .= pb() state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 954309d..05c041d 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -71,7 +71,8 @@ function Base.iterate( state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev - gradient!(state.gradf_xt, iter.f, state.xt) + _, pb = eval_with_pullback(iter.f, state.xt) + state.gradf_xt .= pb() λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. prox!(state.y, iter.g, state.xt - λ2 * state.gradf_xt, λ2) @@ -93,7 +94,8 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty else # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) - gradf_y, = gradient(iter.f, state.y) + _, pb = eval_with_pullback(iter.f, state.y) + gradf_y = pb() r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index fee5ea5..7d5de61 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -84,7 +84,8 @@ f_model(iter::ZeroFPRIteration, state::ZeroFPRState) = f_model(state.f_Ax, state function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - grad_f_Ax, f_Ax = gradient(iter.f, Ax) + f_Ax, pb = eval_with_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -130,7 +131,9 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R f_Axbar_upp, f_Axbar else mul!(state.Axbar, iter.A, state.xbar) - f_model(iter, state), gradient!(state.grad_f_Axbar, iter.f, state.Axbar) + f_Axbar, pb = eval_with_pullback(iter.f, state.Axbar) + state.grad_f_Axbar .= pb() + f_model(iter, state), f_Axbar end # compute FBE @@ -164,7 +167,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.x .= state.xbar_prev .+ state.tau .* state.d state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic - state.f_Ax = gradient!(state.grad_f_Ax, iter.f, state.Ax) + state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax state.g_xbar = prox!(state.xbar, iter.g, state.y, state.gamma) diff --git a/src/utilities/ad.jl b/src/utilities/ad.jl deleted file mode 100644 index f1e43b7..0000000 --- a/src/utilities/ad.jl +++ /dev/null @@ -1,14 +0,0 @@ -using Zygote: pullback -using ProximalCore - -struct ZygoteFunction{F} - f::F -end - -(f::ZygoteFunction)(x) = f.f(x) - -function ProximalCore.gradient!(grad, f::ZygoteFunction, x) - fx, pb = pullback(f.f, x) - grad .= pb(one(fx))[1] - return fx -end diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index 34d6980..f49cebf 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -7,29 +7,31 @@ end function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 - grad_f_Axeps, _ = gradient(f, A * xeps) + _, pb = eval_with_pullback(f, A * xeps) + grad_f_Axeps = pb() return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end function lower_bound_smoothness_constant(f, A, x) Ax = A * x - grad_f_Ax, _ = gradient(f, Ax) + _, pb = eval_with_pullback(f, Ax) + grad_f_Ax = pb() return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end _mul!(y, L, x) = mul!(y, L, x) _mul!(y, ::Nothing, x) = return -_gradient!(y, f, x) = gradient!(y, f, x) -_gradient!(::Nothing, f, x) = f(x) - function backtrack_stepsize!( gamma::R, f, A, g, x, f_Ax::R, At_grad_f_Ax, y, z, g_z::R, res, Az, grad_f_Az=nothing; alpha = 1, minimum_gamma = 1e-7 ) where R f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az = _gradient!(grad_f_Az, f, Az) + f_Az, pb = eval_with_pullback(f, Az) + if grad_f_Az !== nothing + grad_f_Az .= pb() + end tol = 10 * eps(R) * (1 + abs(f_Az)) while f_Az > f_Az_upp + tol && gamma >= minimum_gamma gamma /= 2 @@ -38,7 +40,10 @@ function backtrack_stepsize!( res .= x .- z 
f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az = _gradient!(grad_f_Az, f, Az) + f_Az, pb = eval_with_pullback(f, Az) + if grad_f_Az !== nothing + grad_f_Az .= pb() + end tol = 10 * eps(R) * (1 + abs(f_Az)) end if gamma < minimum_gamma @@ -51,7 +56,8 @@ function backtrack_stepsize!( gamma, f, A, g, x; alpha = 1, minimum_gamma = 1e-7 ) Ax = A * x - grad_f_Ax, f_Ax = gradient(f, Ax) + f_Ax, pb = eval_with_pullback(f, Ax) + grad_f_Ax = pb() At_grad_f_Ax = A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax z, g_z = prox(g, y, gamma) diff --git a/test/Project.toml b/test/Project.toml index ea59c93..10697ef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,3 +8,4 @@ ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index a4fa27f..4f9d07b 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -2,6 +2,7 @@ using Test using LinearAlgebra using ProximalOperators: NormL1 using ProximalAlgorithms +using Zygote @testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] R = real(T) From 7a32038331f9304d5d453086a68b200b4e0a7931 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 1 Jan 2024 12:25:40 +0100 Subject: [PATCH 02/25] try with AbstractDifferentiation (wip) --- Project.toml | 8 +++----- ext/ProximalAlgorithmsZygoteExt.jl | 18 ------------------ src/ProximalAlgorithms.jl | 11 +++++++++-- src/algorithms/davis_yin.jl | 8 ++++---- src/algorithms/fast_forward_backward.jl | 8 ++++---- src/algorithms/forward_backward.jl | 8 ++++---- src/algorithms/li_lin.jl | 12 ++++++------ src/algorithms/panoc.jl | 16 ++++++++-------- src/algorithms/panocplus.jl | 16 ++++++++-------- src/algorithms/primal_dual.jl | 8 ++++---- src/algorithms/sfista.jl | 8 ++++---- src/algorithms/zerofpr.jl | 12 ++++++------ src/utilities/fb_tools.jl | 20 ++++++++++---------- test/Project.toml | 1 + test/runtests.jl | 6 +++--- test/utilities/test_ad.jl | 14 +++----------- 16 files changed, 77 insertions(+), 97 deletions(-) delete mode 100644 ext/ProximalAlgorithmsZygoteExt.jl diff --git a/Project.toml b/Project.toml index 5963e2b..dca17e2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,21 +1,19 @@ name = "ProximalAlgorithms" uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" -version = "0.5.5" +version = "0.6.0" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" -[weakdeps] -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - [extensions] ProximalAlgorithmsZygoteExt = "Zygote" [compat] +AbstractDifferentiation = "0.6" LinearAlgebra = "1.2" Printf = "1.2" ProximalCore = "0.1" -Zygote = "0.6" julia = "1.2" diff --git a/ext/ProximalAlgorithmsZygoteExt.jl b/ext/ProximalAlgorithmsZygoteExt.jl deleted file mode 100644 index c5f466c..0000000 --- a/ext/ProximalAlgorithmsZygoteExt.jl +++ /dev/null @@ -1,18 +0,0 @@ -module ProximalAlgorithmsZygoteExt - -using ProximalAlgorithms -using Zygote: pullback - -struct ZygoteFunction{F} - f::F -end - -(f::ZygoteFunction)(x) = f.f(x) - -function ProximalAlgorithms.eval_with_pullback(f::ZygoteFunction, x) - out, pb = pullback(f, x) - zygote_pullback() = pb(one(out))[1] - return 
out, zygote_pullback -end - -end diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 0bd51cc..51c7ab5 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,14 +1,21 @@ module ProximalAlgorithms +using AbstractDifferentiation: value_and_pullback_function using ProximalCore using ProximalCore: prox, prox! const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} -# gradient computation +_ad_backend = nothing -eval_with_pullback(f::F, x::X) where {F, X} = error("undefined `eval_with_pullback` for function type $F on argument of type $X") +ad_backend() = _ad_backend + +function ad_backend(backend) + global _ad_backend + _ad_backend = backend + _ad_backend +end # various utilities diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index baf7be8..175a04c 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -55,8 +55,8 @@ end function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) - _, pb = eval_with_pullback(iter.f, xg) - grad_f_xg = pb() + f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, xg) + grad_f_xg = pb(one(f_xg)) z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) res = xh - xg @@ -67,8 +67,8 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) - _, pb = eval_with_pullback(iter.f, state.xg) - state.grad_f_xg .= pb() + f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, state.xg) + state.grad_f_xg .= pb(one(f_xg)) state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) state.res .= state.xh .- state.xg diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index bd96a21..84a958b 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -68,8 +68,8 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = eval_with_pullback(iter.f, x) - grad_f_x = pb() + f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + grad_f_x = pb(one(f_x)) gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -104,8 +104,8 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.x .= state.z .+ beta .* (state.z .- state.z_prev) state.z_prev, state.z = state.z, state.z_prev - state.f_x, pb = eval_with_pullback(iter.f, state.x) - state.grad_f_x .= pb() + state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + state.grad_f_x .= pb(one(state.f_x)) state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) state.res .= state.x .- state.z diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index e651b58..f1dc26c 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -59,8 +59,8 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = eval_with_pullback(iter.f, x) - grad_f_x = pb() + f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + grad_f_x = pb(one(f_x)) gamma = iter.gamma === nothing ? 
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -82,8 +82,8 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x else state.x, state.z = state.z, state.x - state.f_x, pb = eval_with_pullback(iter.f, state.x) - state.grad_f_x .= pb() + state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + state.grad_f_x .= pb(one(state.f_x)) end state.y .= state.x .- state.gamma .* state.grad_f_x diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index 3be37c7..47a676f 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -62,8 +62,8 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) - f_y, pb = eval_with_pullback(iter.f, y) - grad_f_y = pb() + f_y, pb = value_and_pullback_function(ad_backend(), iter.f, y) + grad_f_y = pb(one(f_y)) # TODO: initialize gamma if not provided # TODO: authors suggest Barzilai-Borwein rule? @@ -103,8 +103,8 @@ function Base.iterate( else # TODO: re-use available space in state? # TODO: backtrack gamma at x - f_x, pb = eval_with_pullback(iter.f, x) - grad_f_x = pb() + f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + grad_f_x = pb(one(f_x)) x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) Fv = iter.f(v) + g_v @@ -123,8 +123,8 @@ function Base.iterate( Fx = Fv end - state.f_y, pb = eval_with_pullback(iter.f, state.y) - state.grad_f_y .= pb() + state.f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) + state.grad_f_y .= pb(one(state.f_y)) state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index 78eb7ff..fb2b6b9 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -86,8 +86,8 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = eval_with_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + grad_f_Ax = pb(one(f_Ax)) gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -153,8 +153,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad - state.f_Ax_d, pb = eval_with_pullback(iter.f, state.Ax_d) - state.grad_f_Ax_d .= pb() + state.f_Ax_d, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax_d) + state.grad_f_Ax_d .= pb(one(state.f_Ax_d)) mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) copyto!(state.x, state.x_d) @@ -191,8 +191,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # along a line using interpolation and linear combinations # this allows saving operations if isinf(f_Az) - f_Az, pb = eval_with_pullback(iter.f, state.Az) - state.grad_f_Az .= pb() + f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + state.grad_f_Az .= pb(one(f_Az)) end if isinf(c) mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az) @@ -206,8 +206,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where else # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step - state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.grad_f_Ax .= pb(one(state.f_Ax)) mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index 5d113c4..c073d16 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -79,8 +79,8 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = f_model(state.f_Ax, s function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = eval_with_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + grad_f_Ax = pb(one(f_Ax)) gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -98,8 +98,8 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} ) else mul!(state.Az, iter.A, state.z) - _, pb = eval_with_pullback(iter.f, state.Az) - state.grad_f_Az = pb() + f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + state.grad_f_Az = pb(one(f_Az)) end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) return state, state @@ -154,8 +154,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where end mul!(state.Ax, iter.A, state.x) - state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.grad_f_Ax .= pb(one(state.f_Ax)) mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax @@ -165,8 +165,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where f_Az_upp = f_model(iter, state) mul!(state.Az, iter.A, state.z) - f_Az, pb = eval_with_pullback(iter.f, state.Az) - state.grad_f_Az .= pb() + f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + state.grad_f_Az .= pb(one(f_Az)) if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index 9a4da4e..61993bd 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -167,8 +167,8 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step - _, pb = eval_with_pullback(iter.f, state.x) - state.gradf .= pb() + f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + state.gradf .= pb(one(f_x)) mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf state.temp_x .*= -iter.gamma[1] @@ -176,8 +176,8 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - _, pb = eval_with_pullback(convex_conjugate(iter.l), state.y) - state.gradl .= pb() + lc_y, pb = value_and_pullback_function(ad_backend(), convex_conjugate(iter.l), state.y) + state.gradl .= pb(one(lc_y)) state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 05c041d..2be4660 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -71,8 +71,8 @@ function Base.iterate( state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev - _, pb = eval_with_pullback(iter.f, state.xt) - state.gradf_xt .= pb() + f_xt, pb = value_and_pullback_function(ad_backend(), iter.f, state.xt) + state.gradf_xt .= pb(one(f_xt)) λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. prox!(state.y, iter.g, state.xt - λ2 * state.gradf_xt, λ2) @@ -94,8 +94,8 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty else # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) - _, pb = eval_with_pullback(iter.f, state.y) - gradf_y = pb() + f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) + gradf_y = pb(one(f_y)) r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 7d5de61..5047368 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -84,8 +84,8 @@ f_model(iter::ZeroFPRIteration, state::ZeroFPRState) = f_model(state.f_Ax, state function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = eval_with_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + grad_f_Ax = pb(one(f_Ax)) gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -131,8 +131,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R f_Axbar_upp, f_Axbar else mul!(state.Axbar, iter.A, state.xbar) - f_Axbar, pb = eval_with_pullback(iter.f, state.Axbar) - state.grad_f_Axbar .= pb() + f_Axbar, pb = value_and_pullback_function(ad_backend(), iter.f, state.Axbar) + state.grad_f_Axbar .= pb(one(f_Axbar)) f_model(iter, state), f_Axbar end @@ -167,8 +167,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.x .= state.xbar_prev .+ state.tau .* state.d state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic - state.f_Ax, pb = eval_with_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.grad_f_Ax .= pb(one(state.f_Ax)) mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax state.g_xbar = prox!(state.xbar, iter.g, state.y, state.gamma) diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index f49cebf..7e537d9 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -7,15 +7,15 @@ end function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 - _, pb = eval_with_pullback(f, A * xeps) - grad_f_Axeps = pb() + f_Axeps, pb = value_and_pullback_function(ad_backend(), f, A * xeps) + grad_f_Axeps = pb(one(f_Axeps)) return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end function lower_bound_smoothness_constant(f, A, x) Ax = A * x - _, pb = eval_with_pullback(f, Ax) - grad_f_Ax = pb() + f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) + grad_f_Ax = pb(one(f_Ax)) return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end @@ -28,9 +28,9 @@ function backtrack_stepsize!( ) where R f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = eval_with_pullback(f, Az) + f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) if grad_f_Az !== nothing - grad_f_Az .= pb() + grad_f_Az .= pb(one(f_Az)) end tol = 10 * eps(R) * (1 + abs(f_Az)) while f_Az > f_Az_upp + tol && gamma >= minimum_gamma @@ -40,9 +40,9 @@ function backtrack_stepsize!( res .= x .- z f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = eval_with_pullback(f, Az) + f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) if grad_f_Az !== nothing - grad_f_Az .= pb() + grad_f_Az .= pb(one(f_Az)) end tol = 10 * eps(R) * (1 + abs(f_Az)) end @@ 
-56,8 +56,8 @@ function backtrack_stepsize!( gamma, f, A, g, x; alpha = 1, minimum_gamma = 1e-7 ) Ax = A * x - f_Ax, pb = eval_with_pullback(f, Ax) - grad_f_Ax = pb() + f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) + grad_f_Ax = pb(one(f_Ax)) At_grad_f_Ax = A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax z, g_z = prox(g, y, gamma) diff --git a/test/Project.toml b/test/Project.toml index 10697ef..d1811fe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/runtests.jl b/test/runtests.jl index 6532278..ebe17bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,9 +2,9 @@ using Test using Aqua using ProximalAlgorithms -@testset "Aqua" begin - Aqua.test_all(ProximalAlgorithms; ambiguities=false) -end +# @testset "Aqua" begin +# Aqua.test_all(ProximalAlgorithms; ambiguities=false) +# end include("definitions/arraypartition.jl") include("definitions/compose.jl") diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index 4f9d07b..10aca78 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -2,7 +2,7 @@ using Test using LinearAlgebra using ProximalOperators: NormL1 using ProximalAlgorithms -using Zygote +using AbstractDifferentiation: ZygoteBackend @testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] R = real(T) @@ -13,19 +13,11 @@ using Zygote -1.0 -1.0 -1.0 1.0 3.0 ] b = T[1.0, 2.0, 3.0, 4.0] - f = ProximalAlgorithms.ZygoteFunction( - x -> R(1/2) * norm(A * x - b, 2)^2 - ) + f = x -> R(1/2) * norm(A * x - b, 2)^2 Lf = opnorm(A)^2 m, n = size(A) - @testset "Gradient" begin - x = randn(T, n) - gradfx, fx = ProximalAlgorithms.gradient(f, x) - @test eltype(gradfx) == T - @test typeof(fx) == R - @test gradfx ≈ A' * (A * x - b) - end + ProximalAlgorithms.ad_backend(ZygoteBackend()) @testset "Algorithms" begin lam = R(0.1) * norm(A' * b, Inf) From 4a2bbb8819d1a5a09cd367bebcd27a75a699a81b Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 2 Jan 2024 16:27:53 +0100 Subject: [PATCH 03/25] fixes --- Project.toml | 3 --- src/ProximalAlgorithms.jl | 4 +++- src/algorithms/davis_yin.jl | 4 ++-- src/algorithms/fast_forward_backward.jl | 4 ++-- src/algorithms/forward_backward.jl | 4 ++-- src/algorithms/li_lin.jl | 6 +++--- src/algorithms/panoc.jl | 8 ++++---- src/algorithms/panocplus.jl | 8 ++++---- src/algorithms/primal_dual.jl | 4 ++-- src/algorithms/sfista.jl | 4 ++-- src/algorithms/zerofpr.jl | 6 +++--- src/utilities/fb_tools.jl | 11 ++++++----- test/runtests.jl | 6 +++--- test/utilities/test_ad.jl | 1 + test/utilities/test_fb_tools.jl | 4 ++-- 15 files changed, 39 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index dca17e2..e8e7aab 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" -[extensions] -ProximalAlgorithmsZygoteExt = "Zygote" - [compat] AbstractDifferentiation = "0.6" LinearAlgebra = "1.2" diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 51c7ab5..e1b13a2 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -9,7 +9,9 @@ const Maybe{T} = Union{T,Nothing} _ad_backend = nothing -ad_backend() = _ad_backend +function 
ad_backend() + _ad_backend +end function ad_backend(backend) global _ad_backend diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index 175a04c..fea2af6 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -56,7 +56,7 @@ function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, xg) - grad_f_xg = pb(one(f_xg)) + grad_f_xg = pb(one(f_xg))[1] z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) res = xh - xg @@ -68,7 +68,7 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, state.xg) - state.grad_f_xg .= pb(one(f_xg)) + state.grad_f_xg .= pb(one(f_xg))[1] state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) state.res .= state.xh .- state.xg diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index 84a958b..752642f 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -69,7 +69,7 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) - grad_f_x = pb(one(f_x)) + grad_f_x = pb(one(f_x))[1] gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -105,7 +105,7 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.z_prev, state.z = state.z, state.z_prev state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) - state.grad_f_x .= pb(one(state.f_x)) + state.grad_f_x .= pb(one(state.f_x))[1] state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) state.res .= state.x .- state.z diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index f1dc26c..9516360 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -60,7 +60,7 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) - grad_f_x = pb(one(f_x)) + grad_f_x = pb(one(f_x))[1] gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -83,7 +83,7 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS else state.x, state.z = state.z, state.x state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) - state.grad_f_x .= pb(one(state.f_x)) + state.grad_f_x .= pb(one(state.f_x))[1] end state.y .= state.x .- state.gamma .* state.grad_f_x diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index 47a676f..ebba028 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -63,7 +63,7 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) f_y, pb = value_and_pullback_function(ad_backend(), iter.f, y) - grad_f_y = pb(one(f_y)) + grad_f_y = pb(one(f_y))[1] # TODO: initialize gamma if not provided # TODO: authors suggest Barzilai-Borwein rule? @@ -104,7 +104,7 @@ function Base.iterate( # TODO: re-use available space in state? 
# TODO: backtrack gamma at x f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) - grad_f_x = pb(one(f_x)) + grad_f_x = pb(one(f_x))[1] x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) Fv = iter.f(v) + g_v @@ -124,7 +124,7 @@ function Base.iterate( end state.f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) - state.grad_f_y .= pb(one(state.f_y)) + state.grad_f_y .= pb(one(state.f_y))[1] state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index fb2b6b9..e4ff657 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -87,7 +87,7 @@ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) - grad_f_Ax = pb(one(f_Ax)) + grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -154,7 +154,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad state.f_Ax_d, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax_d) - state.grad_f_Ax_d .= pb(one(state.f_Ax_d)) + state.grad_f_Ax_d .= pb(one(state.f_Ax_d))[1] mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) copyto!(state.x, state.x_d) @@ -192,7 +192,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # this allows saving operations if isinf(f_Az) f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) - state.grad_f_Az .= pb(one(f_Az)) + state.grad_f_Az .= pb(one(f_Az))[1] end if isinf(c) mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az) @@ -207,7 +207,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax)) + state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index c073d16..4031f26 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -80,7 +80,7 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) - grad_f_Ax = pb(one(f_Ax)) + grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -99,7 +99,7 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} else mul!(state.Az, iter.A, state.z) f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) - state.grad_f_Az = pb(one(f_Az)) + state.grad_f_Az = pb(one(f_Az))[1] end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) return state, state @@ -155,7 +155,7 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where mul!(state.Ax, iter.A, state.x) state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax)) + state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax @@ -166,7 +166,7 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where mul!(state.Az, iter.A, state.z) f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) - state.grad_f_Az .= pb(one(f_Az)) + state.grad_f_Az .= pb(one(f_Az))[1] if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index 61993bd..b9f6fa4 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -168,7 +168,7 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) - state.gradf .= pb(one(f_x)) + state.gradf .= pb(one(f_x))[1] mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf state.temp_x .*= -iter.gamma[1] @@ -177,7 +177,7 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i # perform ybar-update step lc_y, pb = value_and_pullback_function(ad_backend(), convex_conjugate(iter.l), state.y) - state.gradl .= pb(one(lc_y)) + state.gradl .= pb(one(lc_y))[1] state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 2be4660..55009f2 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -72,7 +72,7 @@ function Base.iterate( state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev f_xt, pb = value_and_pullback_function(ad_backend(), iter.f, state.xt) - state.gradf_xt .= pb(one(f_xt)) + state.gradf_xt .= pb(one(f_xt))[1] λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. prox!(state.y, iter.g, state.xt - λ2 * state.gradf_xt, λ2) @@ -95,7 +95,7 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) - gradf_y = pb(one(f_y)) + gradf_y = pb(one(f_y))[1] r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 5047368..2240a48 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -85,7 +85,7 @@ function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) - grad_f_Ax = pb(one(f_Ax)) + grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -132,7 +132,7 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R else mul!(state.Axbar, iter.A, state.xbar) f_Axbar, pb = value_and_pullback_function(ad_backend(), iter.f, state.Axbar) - state.grad_f_Axbar .= pb(one(f_Axbar)) + state.grad_f_Axbar .= pb(one(f_Axbar))[1] f_model(iter, state), f_Axbar end @@ -168,7 +168,7 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax)) + state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax state.g_xbar = prox!(state.xbar, iter.g, state.y, state.gamma) diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index 7e537d9..83e9a78 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -8,14 +8,15 @@ function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 f_Axeps, pb = value_and_pullback_function(ad_backend(), f, A * xeps) - grad_f_Axeps = pb(one(f_Axeps)) + grad_f_Axeps = pb(one(R))[1] return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end function lower_bound_smoothness_constant(f, A, x) + R = real(eltype(x)) Ax = A * x f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) - grad_f_Ax = pb(one(f_Ax)) + grad_f_Ax = pb(one(R))[1] return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end @@ -30,7 +31,7 @@ function backtrack_stepsize!( _mul!(Az, A, z) f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) if grad_f_Az !== nothing - grad_f_Az .= pb(one(f_Az)) + grad_f_Az .= pb(one(f_Az))[1] end tol = 10 * eps(R) * (1 + abs(f_Az)) while f_Az > f_Az_upp + tol && gamma >= minimum_gamma @@ -42,7 +43,7 @@ function backtrack_stepsize!( _mul!(Az, A, z) f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) if grad_f_Az !== nothing - grad_f_Az .= pb(one(f_Az)) + grad_f_Az .= pb(one(f_Az))[1] end tol = 10 * eps(R) * (1 + abs(f_Az)) end @@ -57,7 +58,7 @@ function backtrack_stepsize!( ) Ax = A * x f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) - grad_f_Ax = pb(one(f_Ax)) + grad_f_Ax = pb(one(f_Ax))[1] At_grad_f_Ax = A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax z, g_z = prox(g, y, gamma) diff --git a/test/runtests.jl b/test/runtests.jl index ebe17bf..6532278 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,9 +2,9 @@ using Test using Aqua using ProximalAlgorithms -# @testset "Aqua" begin -# Aqua.test_all(ProximalAlgorithms; ambiguities=false) -# end +@testset "Aqua" 
begin + Aqua.test_all(ProximalAlgorithms; ambiguities=false) +end include("definitions/arraypartition.jl") include("definitions/compose.jl") diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index 10aca78..efd2d1f 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -2,6 +2,7 @@ using Test using LinearAlgebra using ProximalOperators: NormL1 using ProximalAlgorithms +using Zygote using AbstractDifferentiation: ZygoteBackend @testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] diff --git a/test/utilities/test_fb_tools.jl b/test/utilities/test_fb_tools.jl index 75a3d50..2c5386c 100644 --- a/test/utilities/test_fb_tools.jl +++ b/test/utilities/test_fb_tools.jl @@ -1,7 +1,7 @@ using Test using LinearAlgebra +using ProximalCore: Zero using ProximalAlgorithms: lower_bound_smoothness_constant, backtrack_stepsize! -using ProximalOperators: Quadratic, Zero @testset "Lipschitz constant estimation" for R in [Float32, Float64] @@ -11,7 +11,7 @@ U, _ = qr(randn(n, n)) Q = U * Diagonal(sv) * U' q = randn(n) -f = Quadratic(Q, q) +f(x) = 0.5 * dot(x, Q * x) + dot(q, x) Lf = maximum(sv) g = Zero() From 9c15b0e6bd1f47a4cf78e26f426b0c4d22522197 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 2 Jan 2024 17:13:03 +0100 Subject: [PATCH 04/25] fixes --- test/Project.toml | 2 ++ test/utilities/test_ad.jl | 43 +++++++++++++++++++++------------ test/utilities/test_fb_tools.jl | 2 +- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index d1811fe..7241417 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,11 +2,13 @@ AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index efd2d1f..1b1a47e 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -3,9 +3,19 @@ using LinearAlgebra using ProximalOperators: NormL1 using ProximalAlgorithms using Zygote -using AbstractDifferentiation: ZygoteBackend +using ReverseDiff +using ForwardDiff +using AbstractDifferentiation: value_and_pullback_function +using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBackend + +@testset "Autodiff backend ($B on $T)" for (T, B) in Iterators.product( + [Float32, Float64, ComplexF32, ComplexF64], + [ZygoteBackend, ReverseDiffBackend, ForwardDiffBackend], +) + if T <: Complex && B in [ReverseDiffBackend, ForwardDiffBackend] + continue + end -@testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] R = real(T) A = T[ 1.0 -2.0 3.0 -4.0 5.0 @@ -18,18 +28,21 @@ using AbstractDifferentiation: ZygoteBackend Lf = opnorm(A)^2 m, n = size(A) - ProximalAlgorithms.ad_backend(ZygoteBackend()) + x0 = zeros(T, n) - @testset "Algorithms" begin - lam = R(0.1) * norm(A' * b, Inf) - @test typeof(lam) == R - g = NormL1(lam) - x_star = T[-3.877278911564627e-01, 0, 0, 
2.174149659863943e-02, 6.168435374149660e-01] - TOL = R(1e-4) - solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = solver(x0 = zeros(T, n), f = f, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 100 - end + f_x0, pb = value_and_pullback_function(B(), f, x0) + grad_f_x0 = @inferred pb(one(R))[1] + + ProximalAlgorithms.ad_backend(B()) + + lam = R(0.1) * norm(A' * b, Inf) + @test typeof(lam) == R + g = NormL1(lam) + x_star = T[-3.877278911564627e-01, 0, 0, 2.174149659863943e-02, 6.168435374149660e-01] + TOL = R(1e-4) + solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) + x, it = solver(x0 = x0, f = f, g = g, Lf = Lf) + @test eltype(x) == T + @test norm(x - x_star, Inf) <= TOL + @test it < 100 end diff --git a/test/utilities/test_fb_tools.jl b/test/utilities/test_fb_tools.jl index 2c5386c..97c56ce 100644 --- a/test/utilities/test_fb_tools.jl +++ b/test/utilities/test_fb_tools.jl @@ -11,7 +11,7 @@ U, _ = qr(randn(n, n)) Q = U * Diagonal(sv) * U' q = randn(n) -f(x) = 0.5 * dot(x, Q * x) + dot(q, x) +f(x) = R(0.5) * dot(x, Q * x) + dot(q, x) Lf = maximum(sv) g = Zero() From 585e244ebe4110bf985afbe6c99914b22dca3418 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 2 Jan 2024 23:52:17 +0100 Subject: [PATCH 05/25] streamline --- src/ProximalAlgorithms.jl | 19 ++++----- src/algorithms/davis_yin.jl | 4 +- src/algorithms/fast_forward_backward.jl | 4 +- src/algorithms/forward_backward.jl | 4 +- src/algorithms/li_lin.jl | 6 +-- src/algorithms/panoc.jl | 8 ++-- src/algorithms/panocplus.jl | 8 ++-- src/algorithms/primal_dual.jl | 4 +- src/algorithms/sfista.jl | 4 +- src/algorithms/zerofpr.jl | 6 +-- src/utilities/fb_tools.jl | 10 ++--- test/problems/test_elasticnet.jl | 16 +++---- test/problems/test_equivalence.jl | 9 ++-- test/runtests.jl | 15 +++++++ test/utilities/test_ad.jl | 4 +- test/utilities/test_fb_tools.jl | 56 +++++++++++++------------ 16 files changed, 97 insertions(+), 80 deletions(-) diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index e1b13a2..d044344 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,23 +1,22 @@ module ProximalAlgorithms -using AbstractDifferentiation: value_and_pullback_function +using AbstractDifferentiation using ProximalCore using ProximalCore: prox, prox! 
const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} -_ad_backend = nothing - -function ad_backend() - _ad_backend +struct AutoDifferentiable{F, B} + f::F + backend::B end -function ad_backend(backend) - global _ad_backend - _ad_backend = backend - _ad_backend -end +(f::AutoDifferentiable)(x) = f.f(x) + +value_and_pullback_function(f::AutoDifferentiable, x) = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x) + +value_and_pullback_function(f::ProximalCore.Zero, x) = f(x), _ -> (zero(x)) # various utilities diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index fea2af6..c6fc18e 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -55,7 +55,7 @@ end function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) - f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, xg) + f_xg, pb = value_and_pullback_function(iter.f, xg) grad_f_xg = pb(one(f_xg))[1] z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) @@ -67,7 +67,7 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) - f_xg, pb = value_and_pullback_function(ad_backend(), iter.f, state.xg) + f_xg, pb = value_and_pullback_function(iter.f, state.xg) state.grad_f_xg .= pb(one(f_xg))[1] state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index 752642f..a6b4d7b 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -68,7 +68,7 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + f_x, pb = value_and_pullback_function(iter.f, x) grad_f_x = pb(one(f_x))[1] gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x @@ -104,7 +104,7 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.x .= state.z .+ beta .* (state.z .- state.z_prev) state.z_prev, state.z = state.z, state.z_prev - state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + state.f_x, pb = value_and_pullback_function(iter.f, state.x) state.grad_f_x .= pb(one(state.f_x))[1] state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index 9516360..d321251 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -59,7 +59,7 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + f_x, pb = value_and_pullback_function(iter.f, x) grad_f_x = pb(one(f_x))[1] gamma = iter.gamma === nothing ? 
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x @@ -82,7 +82,7 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x else state.x, state.z = state.z, state.x - state.f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + state.f_x, pb = value_and_pullback_function(iter.f, state.x) state.grad_f_x .= pb(one(state.f_x))[1] end diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index ebba028..b147bbc 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -62,7 +62,7 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) - f_y, pb = value_and_pullback_function(ad_backend(), iter.f, y) + f_y, pb = value_and_pullback_function(iter.f, y) grad_f_y = pb(one(f_y))[1] # TODO: initialize gamma if not provided @@ -103,7 +103,7 @@ function Base.iterate( else # TODO: re-use available space in state? # TODO: backtrack gamma at x - f_x, pb = value_and_pullback_function(ad_backend(), iter.f, x) + f_x, pb = value_and_pullback_function(iter.f, x) grad_f_x = pb(one(f_x))[1] x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) @@ -123,7 +123,7 @@ function Base.iterate( Fx = Fv end - state.f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) + state.f_y, pb = value_and_pullback_function(iter.f, state.y) state.grad_f_y .= pb(one(state.f_y))[1] state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index e4ff657..dfa20d0 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -86,7 +86,7 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + f_Ax, pb = value_and_pullback_function(iter.f, Ax) grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax @@ -153,7 +153,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad - state.f_Ax_d, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax_d) + state.f_Ax_d, pb = value_and_pullback_function(iter.f, state.Ax_d) state.grad_f_Ax_d .= pb(one(state.f_Ax_d))[1] mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) @@ -191,7 +191,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # along a line using interpolation and linear combinations # this allows saving operations if isinf(f_Az) - f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + f_Az, pb = value_and_pullback_function(iter.f, state.Az) state.grad_f_Az .= pb(one(f_Az))[1] end if isinf(c) @@ -206,7 +206,7 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where else # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step - state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index 4031f26..26633da 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -79,7 +79,7 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = f_model(state.f_Ax, s function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + f_Ax, pb = value_and_pullback_function(iter.f, Ax) grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax @@ -98,7 +98,7 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} ) else mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + f_Az, pb = value_and_pullback_function(iter.f, state.Az) state.grad_f_Az = pb(one(f_Az))[1] end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) @@ -154,7 +154,7 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where end mul!(state.Ax, iter.A, state.x) - state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) @@ -165,7 +165,7 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where f_Az_upp = f_model(iter, state) mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback_function(ad_backend(), iter.f, state.Az) + f_Az, pb = value_and_pullback_function(iter.f, state.Az) state.grad_f_Az .= pb(one(f_Az))[1] if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index b9f6fa4..3e83a90 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -167,7 +167,7 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step - f_x, pb = value_and_pullback_function(ad_backend(), iter.f, state.x) + f_x, pb = value_and_pullback_function(iter.f, state.x) state.gradf .= pb(one(f_x))[1] mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf @@ -176,7 +176,7 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - lc_y, pb = value_and_pullback_function(ad_backend(), convex_conjugate(iter.l), state.y) + lc_y, pb = value_and_pullback_function(convex_conjugate(iter.l), state.y) state.gradl .= pb(one(lc_y))[1] state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 55009f2..2967b1f 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -71,7 +71,7 @@ function Base.iterate( state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev - f_xt, pb = value_and_pullback_function(ad_backend(), iter.f, state.xt) + f_xt, pb = value_and_pullback_function(iter.f, state.xt) state.gradf_xt .= pb(one(f_xt))[1] λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. @@ -94,7 +94,7 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty else # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) - f_y, pb = value_and_pullback_function(ad_backend(), iter.f, state.y) + f_y, pb = value_and_pullback_function(iter.f, state.y) gradf_y = pb(one(f_y))[1] r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 2240a48..da247a1 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -84,7 +84,7 @@ f_model(iter::ZeroFPRIteration, state::ZeroFPRState) = f_model(state.f_Ax, state function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, Ax) + f_Ax, pb = value_and_pullback_function(iter.f, Ax) grad_f_Ax = pb(one(f_Ax))[1] gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax @@ -131,7 +131,7 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R f_Axbar_upp, f_Axbar else mul!(state.Axbar, iter.A, state.xbar) - f_Axbar, pb = value_and_pullback_function(ad_backend(), iter.f, state.Axbar) + f_Axbar, pb = value_and_pullback_function(iter.f, state.Axbar) state.grad_f_Axbar .= pb(one(f_Axbar))[1] f_model(iter, state), f_Axbar end @@ -167,7 +167,7 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.x .= state.xbar_prev .+ state.tau .* state.d state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic - state.f_Ax, pb = value_and_pullback_function(ad_backend(), iter.f, state.Ax) + state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) state.grad_f_Ax .= pb(one(state.f_Ax))[1] mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index 83e9a78..8d31506 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -7,7 +7,7 @@ end function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 - f_Axeps, pb = value_and_pullback_function(ad_backend(), f, A * xeps) + f_Axeps, pb = value_and_pullback_function(f, A * xeps) grad_f_Axeps = pb(one(R))[1] return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end @@ -15,7 +15,7 @@ end function lower_bound_smoothness_constant(f, A, x) R = real(eltype(x)) Ax = A * x - f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) + f_Ax, pb = value_and_pullback_function(f, Ax) grad_f_Ax = pb(one(R))[1] return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end @@ -29,7 +29,7 @@ function backtrack_stepsize!( ) where R f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) + f_Az, pb = value_and_pullback_function(f, Az) if grad_f_Az !== nothing grad_f_Az .= pb(one(f_Az))[1] end @@ -41,7 +41,7 @@ function backtrack_stepsize!( res .= x .- z f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback_function(ad_backend(), f, Az) + f_Az, pb = value_and_pullback_function(f, Az) if grad_f_Az !== nothing grad_f_Az .= pb(one(f_Az))[1] end @@ -57,7 +57,7 @@ function backtrack_stepsize!( gamma, f, A, g, x; alpha = 1, minimum_gamma = 1e-7 ) Ax = A * x - f_Ax, pb = value_and_pullback_function(ad_backend(), f, Ax) + f_Ax, pb = value_and_pullback_function(f, Ax) grad_f_Ax = pb(one(f_Ax))[1] At_grad_f_Ax = A' * grad_f_Ax y = 
x - gamma .* At_grad_f_Ax diff --git a/test/problems/test_elasticnet.jl b/test/problems/test_elasticnet.jl index 72e52b8..29e4052 100644 --- a/test/problems/test_elasticnet.jl +++ b/test/problems/test_elasticnet.jl @@ -1,8 +1,10 @@ -@testset "Elastic net ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra +using LinearAlgebra +using ProximalOperators: NormL1, SqrNormL2, ElasticNet, Translate +using ProximalAlgorithms +using Zygote +using AbstractDifferentiation: ZygoteBackend +@testset "Elastic net ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] A = T[ 1.0 -2.0 3.0 -4.0 5.0 2.0 -1.0 0.0 -1.0 3.0 @@ -19,7 +21,7 @@ reg1 = NormL1(R(1)) reg2 = SqrNormL2(R(1)) loss = Translate(SqrNormL2(R(1)), -b) - cost = LeastSquares(A, b) + cost = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2) / 2, ZygoteBackend()) L = opnorm(A)^2 @@ -69,7 +71,7 @@ solver = ProximalAlgorithms.AFBA(theta = theta, mu = mu, tol = R(1e-6)) (x_afba, y_afba), it_afba = - solver(x0 = x0, y0 = y0, f = reg2, g = reg1, h = loss, L = A, beta_f = 1) + solver(x0 = x0, y0 = y0, f = ProximalAlgorithms.AutoDifferentiable(reg2, ZygoteBackend()), g = reg1, h = loss, L = A, beta_f = 1) @test eltype(x_afba) == T @test eltype(y_afba) == T @test norm(x_afba - x_star, Inf) <= 1e-4 @@ -86,7 +88,7 @@ solver = ProximalAlgorithms.AFBA(theta = theta, mu = mu, tol = R(1e-6)) (x_afba, y_afba), it_afba = - solver(x0 = x0, y0 = y0, f = reg2, g = reg1, h = loss, L = A, beta_f = 1) + solver(x0 = x0, y0 = y0, f = ProximalAlgorithms.AutoDifferentiable(reg2, ZygoteBackend()), g = reg1, h = loss, L = A, beta_f = 1) @test eltype(x_afba) == T @test eltype(y_afba) == T @test norm(x_afba - x_star, Inf) <= 1e-4 diff --git a/test/problems/test_equivalence.jl b/test/problems/test_equivalence.jl index 004d618..de8f970 100644 --- a/test/problems/test_equivalence.jl +++ b/test/problems/test_equivalence.jl @@ -1,7 +1,8 @@ using LinearAlgebra using Test - -using ProximalOperators +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: LeastSquares, NormL1 using ProximalAlgorithms: DouglasRachfordIteration, DRLSIteration, ForwardBackwardIteration, PANOCIteration, @@ -51,7 +52,7 @@ end lam = R(0.1) * norm(A' * b, Inf) - f = LeastSquares(A, b) + f = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2) / 2, ZygoteBackend()) g = NormL1(lam) x0 = zeros(R, n) @@ -79,7 +80,7 @@ end lam = R(0.1) * norm(A' * b, Inf) - f = LeastSquares(A, b) + f = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2) / 2, ZygoteBackend()) g = NormL1(lam) x0 = zeros(R, n) diff --git a/test/runtests.jl b/test/runtests.jl index 6532278..3a756b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,22 @@ using Test using Aqua +using AbstractDifferentiation using ProximalAlgorithms +struct CustomBackend end + +struct Quadratic{M, V} + Q::M + q::V +end + +(f::Quadratic)(x) = dot(x, f.Q * x) / 2 + dot(f.q, x) + +function ProximalAlgorithms.value_and_pullback_function(f::Quadratic, x) + grad = f.Q * x + f.q + return dot(grad, x) / 2 + dot(f.q, x), v -> (grad,) +end + @testset "Aqua" begin Aqua.test_all(ProximalAlgorithms; ambiguities=false) end diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index 1b1a47e..eac17dd 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -24,7 +24,7 @@ using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBac -1.0 -1.0 -1.0 1.0 3.0 ] b = T[1.0, 2.0, 3.0, 4.0] - 
f = x -> R(1/2) * norm(A * x - b, 2)^2 + f = ProximalAlgorithms.AutoDifferentiable(x -> R(1/2) * norm(A * x - b, 2)^2, B()) Lf = opnorm(A)^2 m, n = size(A) @@ -33,8 +33,6 @@ using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBac f_x0, pb = value_and_pullback_function(B(), f, x0) grad_f_x0 = @inferred pb(one(R))[1] - ProximalAlgorithms.ad_backend(B()) - lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R g = NormL1(lam) diff --git a/test/utilities/test_fb_tools.jl b/test/utilities/test_fb_tools.jl index 97c56ce..8205e08 100644 --- a/test/utilities/test_fb_tools.jl +++ b/test/utilities/test_fb_tools.jl @@ -1,39 +1,41 @@ using Test using LinearAlgebra using ProximalCore: Zero -using ProximalAlgorithms: lower_bound_smoothness_constant, backtrack_stepsize! +using ProximalAlgorithms +using AbstractDifferentiation @testset "Lipschitz constant estimation" for R in [Float32, Float64] -sv = R[0.01, 1.0, 1.0, 1.0, 100.0] -n = length(sv) -U, _ = qr(randn(n, n)) -Q = U * Diagonal(sv) * U' -q = randn(n) + sv = R[0.01, 1.0, 1.0, 1.0, 100.0] + n = length(sv) + U, _ = qr(randn(R, n, n)) + Q = U * Diagonal(sv) * U' + q = randn(R, n) -f(x) = R(0.5) * dot(x, Q * x) + dot(q, x) -Lf = maximum(sv) -g = Zero() + f = Quadratic(Q, q) + Lf = maximum(sv) + g = Zero() -for _ in 1:100 - x = randn(n) - Lest = @inferred lower_bound_smoothness_constant(f, I, x) - @test Lest <= Lf -end - -x = randn(n) -Lest = @inferred lower_bound_smoothness_constant(f, I, x) -alpha = R(0.5) -gamma_init = 10 / Lest -gamma = gamma_init + for _ in 1:100 + x = randn(R, n) + Lest = @inferred ProximalAlgorithms.lower_bound_smoothness_constant(f, I, x) + @test typeof(Lest) == R + @test Lest <= Lf + end -for _ in 1:100 x = randn(n) - new_gamma, = @inferred backtrack_stepsize!(gamma, f, I, g, x, alpha=alpha) - @test new_gamma <= gamma - gamma = new_gamma -end - -@test gamma < gamma_init + Lest = @inferred ProximalAlgorithms.lower_bound_smoothness_constant(f, I, x) + alpha = R(0.5) + gamma_init = 10 / Lest + gamma = gamma_init + + for _ in 1:100 + x = randn(n) + new_gamma, = @inferred ProximalAlgorithms.backtrack_stepsize!(gamma, f, I, g, x, alpha=alpha) + @test new_gamma <= gamma + gamma = new_gamma + end + + @test gamma < gamma_init end From 7cf1c1271cafcb93ac5cb51171f01475f925417c Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 00:02:39 +0100 Subject: [PATCH 06/25] Fixup --- test/runtests.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3a756b4..cd685b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,8 +3,6 @@ using Aqua using AbstractDifferentiation using ProximalAlgorithms -struct CustomBackend end - struct Quadratic{M, V} Q::M q::V From fc700bb5eb9bead2db944c6f1d440e1514708c03 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 00:04:42 +0100 Subject: [PATCH 07/25] Fix --- test/utilities/test_ad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index eac17dd..a642542 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -5,7 +5,6 @@ using ProximalAlgorithms using Zygote using ReverseDiff using ForwardDiff -using AbstractDifferentiation: value_and_pullback_function using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBackend @testset "Autodiff backend ($B on $T)" for (T, B) in Iterators.product( @@ -30,7 +29,7 @@ using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBac x0 = 
zeros(T, n) - f_x0, pb = value_and_pullback_function(B(), f, x0) + f_x0, pb = ProximalAlgorithms.value_and_pullback_function(f, x0) grad_f_x0 = @inferred pb(one(R))[1] lam = R(0.1) * norm(A' * b, Inf) From 05f8334e6635fbb31608c9c075056d97e7b2380e Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 11:44:23 +0100 Subject: [PATCH 08/25] fix more tests --- test/problems/test_lasso_small.jl | 44 ++++++++++--------- test/problems/test_lasso_small_h_split.jl | 38 ++++++++-------- .../test_lasso_small_strongly_convex.jl | 21 +++++---- test/problems/test_lasso_small_v_split.jl | 10 ++--- test/problems/test_linear_programs.jl | 10 +++-- test/problems/test_nonconvex_qp.jl | 8 ++-- test/problems/test_sparse_logistic_small.jl | 29 +++++++----- test/problems/test_verbose.jl | 37 +++++++++------- 8 files changed, 110 insertions(+), 87 deletions(-) diff --git a/test/problems/test_lasso_small.jl b/test/problems/test_lasso_small.jl index ba0d69d..2f1c387 100644 --- a/test/problems/test_lasso_small.jl +++ b/test/problems/test_lasso_small.jl @@ -1,7 +1,9 @@ using LinearAlgebra using Test -using ProximalOperators +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: NormL1, LeastSquares using ProximalAlgorithms using ProximalAlgorithms: LBFGS, Broyden, AndersonAcceleration, @@ -23,8 +25,10 @@ using ProximalAlgorithms: lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R - f = Translate(SqrNormL2(R(1)), -b) - f2 = LeastSquares(A, b) + f_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(x - b)^2)/2, ZygoteBackend()) + fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2)/2, ZygoteBackend()) + f_prox = Translate(SqrNormL2(R(1)), -b) + fA_prox = LeastSquares(A, b) g = NormL1(lam) Lf = opnorm(A)^2 @@ -37,7 +41,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ForwardBackward(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 150 @@ -48,7 +52,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ForwardBackward(tol = TOL, adaptive = true) - x, it = @inferred solver(x0 = x0, f = f2, g = g) + x, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 300 @@ -59,7 +63,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 100 @@ -71,7 +75,7 @@ using ProximalAlgorithms: x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL, adaptive = true) - x, it = @inferred solver(x0 = x0, f = f2, g = g) + x, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 200 @@ -82,7 +86,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f2, g = g, Lf = Lf, extrapolation_sequence=FixedNesterovSequence(real(T))) + x, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf, extrapolation_sequence=FixedNesterovSequence(real(T))) @test eltype(x) == T @test norm(x - x_star, Inf) <= 
TOL @test it < 100 @@ -93,7 +97,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ZeroFPR(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -104,7 +108,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -116,7 +120,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOC(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -128,7 +132,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -139,7 +143,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOCplus(tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -150,7 +154,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL) - x, it = @inferred solver(x0 = x0, f = f, A = A, g = g) + x, it = @inferred solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -161,7 +165,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.DouglasRachford(gamma = R(10) / opnorm(A)^2, tol = TOL) - y, it = @inferred solver(x0 = x0, f = f2, g = g) + y, it = @inferred solver(x0 = x0, f = fA_prox, g = g) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 30 @@ -178,7 +182,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.DRLS(tol = 10 * TOL, directions=acc) - z, it = @inferred solver(x0 = x0, f = f2, g = g, Lf = Lf) + z, it = @inferred solver(x0 = x0, f = fA_prox, g = g, Lf = Lf) @test eltype(z) == T @test norm(z - x_star, Inf) <= 10 * TOL @test it < maxit @@ -189,7 +193,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.AFBA(theta = 1, mu = 1, tol = R(1e-6)) - (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, n), f = f2, g = g, beta_f = opnorm(A)^2) + (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, n), f = fA_autodiff, g = g, beta_f = opnorm(A)^2) @test eltype(x_afba) == T @test eltype(y_afba) == T @test norm(x_afba - x_star, Inf) <= 1e-4 @@ -197,7 +201,7 @@ using ProximalAlgorithms: @test x0 == x0_backup solver = ProximalAlgorithms.AFBA(theta = 1, mu = 1, tol = R(1e-6)) - (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, n), f = f2, h = g, beta_f = opnorm(A)^2) + (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, 
n), f = fA_autodiff, h = g, beta_f = opnorm(A)^2) @test eltype(x_afba) == T @test eltype(y_afba) == T @test norm(x_afba - x_star, Inf) <= 1e-4 @@ -205,7 +209,7 @@ using ProximalAlgorithms: @test x0 == x0_backup solver = ProximalAlgorithms.AFBA(theta = 1, mu = 1, tol = R(1e-6)) - (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, m), h = f, L = A, g = g) + (x_afba, y_afba), it_afba = @inferred solver(x0 = x0, y0 = zeros(T, m), h = f_prox, L = A, g = g) @test eltype(x_afba) == T @test eltype(y_afba) == T @test norm(x_afba - x_star, Inf) <= 1e-4 @@ -217,7 +221,7 @@ using ProximalAlgorithms: x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.SFISTA(tol = 10 * TOL) - y, it = @inferred solver(x0 = x0, f = f2, g = g, Lf = Lf) + y, it = @inferred solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(y) == T @test norm(y - x_star, Inf) <= 10 * TOL @test it < 100 diff --git a/test/problems/test_lasso_small_h_split.jl b/test/problems/test_lasso_small_h_split.jl index 54d0fcf..97de310 100644 --- a/test/problems/test_lasso_small_h_split.jl +++ b/test/problems/test_lasso_small_h_split.jl @@ -1,10 +1,12 @@ -@testset "Lasso small (h. split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra - using AbstractOperators: MatrixOp - using RecursiveArrayTools: ArrayPartition +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: NormL1, SeparableSum +using ProximalAlgorithms +using LinearAlgebra +using AbstractOperators: MatrixOp +using RecursiveArrayTools: ArrayPartition +@testset "Lasso small (h. split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] A1 = T[ 1.0 -2.0 3.0 2.0 -1.0 0.0 @@ -29,8 +31,8 @@ lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R - f = Translate(SqrNormL2(R(1)), -b) - f2 = ComposeAffine(SqrNormL2(R(1)), A, -b) + f_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(x - b)^2)/2, ZygoteBackend()) + fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2)/2, ZygoteBackend()) g = SeparableSum(NormL1(lam), NormL1(lam)) Lf = opnorm([A1 A2])^2 @@ -49,7 +51,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.ForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 150 @@ -60,7 +62,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.ForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 300 @@ -71,7 +73,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 100 @@ -83,7 +85,7 @@ x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 200 @@ -97,7 +99,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.ZeroFPR(tol = TOL) - x, it = 
solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -108,7 +110,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -123,7 +125,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOC(tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -134,7 +136,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -149,7 +151,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOCplus(tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -160,7 +162,7 @@ x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 diff --git a/test/problems/test_lasso_small_strongly_convex.jl b/test/problems/test_lasso_small_strongly_convex.jl index 7dd959b..d2e7e13 100644 --- a/test/problems/test_lasso_small_strongly_convex.jl +++ b/test/problems/test_lasso_small_strongly_convex.jl @@ -1,7 +1,9 @@ using LinearAlgebra using Test -using ProximalOperators +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: NormL1, LeastSquares using ProximalAlgorithms @testset "Lasso small (strongly convex, $T)" for T in [Float32, Float64] @@ -29,7 +31,8 @@ using ProximalAlgorithms A = Q * D * Q' b = A * x_star + lam * inv(A') * sign.(x_star) - f = LeastSquares(A, b) + fA_prox = LeastSquares(A, b) + fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2)/2, ZygoteBackend()) g = NormL1(lam) TOL = T(1e-4) @@ -39,7 +42,7 @@ using ProximalAlgorithms @testset "SFISTA" begin solver = ProximalAlgorithms.SFISTA(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, Lf=Lf, mf=mf) + y, it = solver(x0=x0, f=fA_autodiff, g=g, Lf=Lf, mf=mf) @test eltype(y) == T @test norm(y - x_star) <= TOL @test it < 40 @@ -48,7 +51,7 @@ using ProximalAlgorithms @testset "ForwardBackward" begin solver = ProximalAlgorithms.ForwardBackward(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, Lf=Lf) + y, it = solver(x0=x0, f=fA_autodiff, g=g, Lf=Lf) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 110 @@ -57,7 +60,7 @@ using ProximalAlgorithms @testset "FastForwardBackward" begin solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, Lf=Lf, mf=mf) + y, it = solver(x0=x0, f=fA_autodiff, g=g, Lf=Lf, mf=mf) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 35 @@ -66,7 +69,7 @@ using 
ProximalAlgorithms @testset "FastForwardBackward (custom extrapolation)" begin solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, gamma = 1/Lf, mf=mf, extrapolation_sequence=ProximalAlgorithms.ConstantNesterovSequence(mf, 1/Lf)) + y, it = solver(x0=x0, f=fA_autodiff, g=g, gamma = 1/Lf, mf=mf, extrapolation_sequence=ProximalAlgorithms.ConstantNesterovSequence(mf, 1/Lf)) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 35 @@ -75,7 +78,7 @@ using ProximalAlgorithms @testset "DRLS" begin solver = ProximalAlgorithms.DRLS(tol = TOL) - v, it = solver(x0=x0, f=f, g=g, mf=mf) + v, it = solver(x0=x0, f=fA_prox, g=g, mf=mf) @test eltype(v) == T @test norm(v - x_star, Inf) <= TOL @test it < 14 @@ -84,7 +87,7 @@ using ProximalAlgorithms @testset "PANOC" begin solver = ProximalAlgorithms.PANOC(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, Lf=Lf) + y, it = solver(x0=x0, f=fA_autodiff, g=g, Lf=Lf) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 45 @@ -93,7 +96,7 @@ using ProximalAlgorithms @testset "PANOCplus" begin solver = ProximalAlgorithms.PANOCplus(tol = TOL) - y, it = solver(x0=x0, f=f, g=g, Lf=Lf) + y, it = solver(x0=x0, f=fA_autodiff, g=g, Lf=Lf) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 45 diff --git a/test/problems/test_lasso_small_v_split.jl b/test/problems/test_lasso_small_v_split.jl index 855df8b..80c8230 100644 --- a/test/problems/test_lasso_small_v_split.jl +++ b/test/problems/test_lasso_small_v_split.jl @@ -1,9 +1,9 @@ -@testset "Lasso small (v. split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra - using AbstractOperators: MatrixOp +using ProximalOperators: LeastSquares, NormL1, SeparableSum, Sum, Translate +using ProximalAlgorithms +using LinearAlgebra +using AbstractOperators: MatrixOp +@testset "Lasso small (v. 
split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] A1 = T[ 1.0 -2.0 3.0 -4.0 5.0 2.0 -1.0 0.0 -1.0 3.0 diff --git a/test/problems/test_linear_programs.jl b/test/problems/test_linear_programs.jl index 51e69c6..30f82e5 100644 --- a/test/problems/test_linear_programs.jl +++ b/test/problems/test_linear_programs.jl @@ -1,4 +1,6 @@ -using ProximalOperators +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: Linear, IndNonnegative, IndPoint, IndAffine, SlicedSeparableSum using ProximalAlgorithms using LinearAlgebra @@ -68,7 +70,7 @@ end @testset "AFBA" begin - f = Linear(c) + f = ProximalAlgorithms.AutoDifferentiable(x -> dot(c, x), ZygoteBackend()) g = IndNonnegative() h = IndPoint(b) @@ -94,7 +96,7 @@ end @testset "VuCondat" begin - f = Linear(c) + f = ProximalAlgorithms.AutoDifferentiable(x -> dot(c, x), ZygoteBackend()) g = IndNonnegative() h = IndPoint(b) @@ -143,7 +145,7 @@ end @testset "DavisYin" begin - f = Linear(c) + f = ProximalAlgorithms.AutoDifferentiable(x -> dot(c, x), ZygoteBackend()) g = IndNonnegative() h = IndAffine(A, b) diff --git a/test/problems/test_nonconvex_qp.jl b/test/problems/test_nonconvex_qp.jl index dd3d2e5..6f1dfca 100644 --- a/test/problems/test_nonconvex_qp.jl +++ b/test/problems/test_nonconvex_qp.jl @@ -1,5 +1,7 @@ +using Zygote +using AbstractDifferentiation: ZygoteBackend using ProximalAlgorithms -using ProximalOperators +using ProximalOperators: IndBox using LinearAlgebra using Random using Test @@ -10,7 +12,7 @@ using Test low = T(-1.0) upp = T(+1.0) - f = Quadratic(Q, q) + f = ProximalAlgorithms.AutoDifferentiable(x -> dot(Q * x, x) / 2 + dot(q, x), ZygoteBackend()) g = IndBox(low, upp) n = 2 @@ -76,7 +78,7 @@ end low = T(-1.0) upp = T(+1.0) - f = Quadratic(Q, q) + f = ProximalAlgorithms.AutoDifferentiable(x -> dot(Q * x, x) / 2 + dot(q, x), ZygoteBackend()) g = IndBox(low, upp) Lip = maximum(abs.(eigenvalues)) diff --git a/test/problems/test_sparse_logistic_small.jl b/test/problems/test_sparse_logistic_small.jl index 47a26d0..19b81a7 100644 --- a/test/problems/test_sparse_logistic_small.jl +++ b/test/problems/test_sparse_logistic_small.jl @@ -1,8 +1,10 @@ -@testset "Sparse logistic small ($T)" for T in [Float32, Float64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: NormL1 +using ProximalAlgorithms +using LinearAlgebra +@testset "Sparse logistic small ($T)" for T in [Float32, Float64] A = T[ 1.0 -2.0 3.0 -4.0 5.0 2.0 -1.0 0.0 -1.0 3.0 @@ -15,8 +17,13 @@ R = real(T) - f = Translate(LogisticLoss(ones(R, m), R(1)), -b) - f2 = ComposeAffine(LogisticLoss(ones(R, m), R(1)), A, -b) + function logistic_loss(logits) + u = 1 .+ exp.(-logits) # labels are assumed all one + return sum(log.(u)) + end + + f_autodiff = ProximalAlgorithms.AutoDifferentiable(logistic_loss, ZygoteBackend()) + fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> logistic_loss(A * x - b), ZygoteBackend()) lam = R(0.1) g = NormL1(lam) @@ -29,7 +36,7 @@ x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= 1e-4 @test it < 1100 @@ -40,7 +47,7 @@ x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = 
fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= 1e-4 @test it < 500 @@ -51,7 +58,7 @@ x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= 1e-4 @test it < 25 @@ -62,7 +69,7 @@ x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= 1e-4 @test it < 50 @@ -73,7 +80,7 @@ x0 = zeros(T, n) x0_backup = copy(x0) solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= 1e-4 @test it < 50 diff --git a/test/problems/test_verbose.jl b/test/problems/test_verbose.jl index ff942c1..573d7fe 100644 --- a/test/problems/test_verbose.jl +++ b/test/problems/test_verbose.jl @@ -1,8 +1,10 @@ -@testset "Verbose" for T in [Float64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra +using Zygote +using AbstractDifferentiation: ZygoteBackend +using ProximalOperators: LeastSquares, NormL1 +using ProximalAlgorithms +using LinearAlgebra +@testset "Verbose" for T in [Float64] A = T[ 1.0 -2.0 3.0 -4.0 5.0 2.0 -1.0 0.0 -1.0 3.0 @@ -18,8 +20,9 @@ lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R - f = Translate(SqrNormL2(R(1)), -b) - f2 = LeastSquares(A, b) + f_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(x - b)^2)/2, ZygoteBackend()) + fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2)/2, ZygoteBackend()) + fA_prox = LeastSquares(A, b) g = NormL1(lam) Lf = opnorm(A)^2 @@ -34,7 +37,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.ForwardBackward(tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 150 @@ -47,7 +50,7 @@ adaptive = true, verbose = true, ) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 300 @@ -57,7 +60,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.FastForwardBackward(tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 100 @@ -70,7 +73,7 @@ adaptive = true, verbose = true, ) - x, it = solver(x0 = x0, f = f2, g = g) + x, it = solver(x0 = x0, f = fA_autodiff, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 200 @@ -82,7 +85,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.ZeroFPR(tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -91,7 +94,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -104,7 +107,7 @@ x0 = zeros(T, n) solver = 
ProximalAlgorithms.PANOC(tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -113,7 +116,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -126,7 +129,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.PANOCplus(tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = Lf) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -135,7 +138,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL, verbose = true) - x, it = solver(x0 = x0, f = f, A = A, g = g) + x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) @test eltype(x) == T @test norm(x - x_star, Inf) <= TOL @test it < 20 @@ -152,7 +155,7 @@ tol = TOL, verbose = true, ) - y, it = solver(x0 = x0, f = f2, g = g) + y, it = solver(x0 = x0, f = fA_prox, g = g) @test eltype(y) == T @test norm(y - x_star, Inf) <= TOL @test it < 30 From c56cc083c95d6e90e021b4ff681d7fb28f6985a7 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 13:37:28 +0100 Subject: [PATCH 09/25] remove test and unused code --- test/definitions/arraypartition.jl | 20 --- test/definitions/compose.jl | 29 ---- test/problems/test_lasso_small_h_split.jl | 173 -------------------- test/problems/test_lasso_small_v_split.jl | 164 ------------------- test/problems/test_sparse_logistic_small.jl | 2 +- test/runtests.jl | 5 - 6 files changed, 1 insertion(+), 392 deletions(-) delete mode 100644 test/definitions/arraypartition.jl delete mode 100644 test/definitions/compose.jl delete mode 100644 test/problems/test_lasso_small_h_split.jl delete mode 100644 test/problems/test_lasso_small_v_split.jl diff --git a/test/definitions/arraypartition.jl b/test/definitions/arraypartition.jl deleted file mode 100644 index 87dcb78..0000000 --- a/test/definitions/arraypartition.jl +++ /dev/null @@ -1,20 +0,0 @@ -import ProximalCore -import RecursiveArrayTools - -@inline function ProximalCore.prox(h, x::RecursiveArrayTools.ArrayPartition, gamma...) - # unwrap - y, fy = ProximalCore.prox(h, x.x, gamma...) - # wrap - return RecursiveArrayTools.ArrayPartition(y), fy -end - -@inline function ProximalCore.gradient(h, x::RecursiveArrayTools.ArrayPartition) - # unwrap - grad, fx = ProximalCore.gradient(h, x.x) - # wrap - return RecursiveArrayTools.ArrayPartition(grad), fx -end - -@inline ProximalCore.prox!(y::RecursiveArrayTools.ArrayPartition, h, x::RecursiveArrayTools.ArrayPartition, gamma...) = ProximalCore.prox!(y.x, h, x.x, gamma...) 
- -@inline ProximalCore.gradient!(y::RecursiveArrayTools.ArrayPartition, h, x::RecursiveArrayTools.ArrayPartition) = ProximalCore.gradient!(y.x, h, x.x) diff --git a/test/definitions/compose.jl b/test/definitions/compose.jl deleted file mode 100644 index 8e38c2a..0000000 --- a/test/definitions/compose.jl +++ /dev/null @@ -1,29 +0,0 @@ -using RecursiveArrayTools: ArrayPartition -using ProximalCore - -struct ComposeAffine - f - A - b -end - -function (g::ComposeAffine)(x) - res = g.A * x .+ g.b - return g.f(res) -end - -function compose_affine_gradient!(y, g::ComposeAffine, x) - res = g.A * x .+ g.b - gradres, v = gradient(g.f, res) - mul!(y, adjoint(g.A), gradres) - return v -end - -ProximalCore.gradient!(y, g::ComposeAffine, x) = compose_affine_gradient!(y, g, x) -ProximalCore.gradient!(y::ArrayPartition, g::ComposeAffine, x::ArrayPartition) = compose_affine_gradient!(y, g, x) - -function ProximalCore.gradient(h::ComposeAffine, x::ArrayPartition) - grad_fx = similar(x) - fx = ProximalCore.gradient!(grad_fx, h, x) - return grad_fx, fx -end diff --git a/test/problems/test_lasso_small_h_split.jl b/test/problems/test_lasso_small_h_split.jl deleted file mode 100644 index 97de310..0000000 --- a/test/problems/test_lasso_small_h_split.jl +++ /dev/null @@ -1,173 +0,0 @@ -using Zygote -using AbstractDifferentiation: ZygoteBackend -using ProximalOperators: NormL1, SeparableSum -using ProximalAlgorithms -using LinearAlgebra -using AbstractOperators: MatrixOp -using RecursiveArrayTools: ArrayPartition - -@testset "Lasso small (h. split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - A1 = T[ - 1.0 -2.0 3.0 - 2.0 -1.0 0.0 - -1.0 0.0 4.0 - -1.0 -1.0 -1.0 - ] - A2 = T[ - -4.0 5.0 - -1.0 3.0 - -3.0 2.0 - 1.0 3.0 - ] - A = hcat(MatrixOp(A1), MatrixOp(A2)) - b = T[1.0, 2.0, 3.0, 4.0] - - m, n1 = size(A1) - _, n2 = size(A2) - n = n1 + n2 - - R = real(T) - - lam = R(0.1) * norm(A' * b, Inf) - @test typeof(lam) == R - - f_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(x - b)^2)/2, ZygoteBackend()) - fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> (norm(A*x - b)^2)/2, ZygoteBackend()) - g = SeparableSum(NormL1(lam), NormL1(lam)) - - Lf = opnorm([A1 A2])^2 - - x_star = ArrayPartition( - T[-3.877278911564627e-01, 0, 0], - T[2.174149659863943e-02, 6.168435374149660e-01], - ) - - TOL = R(1e-4) - - @testset "ForwardBackward" begin - - ## Nonfast/Nonadaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 150 - @test x0 == x0_backup - - # Nonfast/Adaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = fA_autodiff, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 300 - @test x0 == x0_backup - - # Fast/Nonadaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = fA_autodiff, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 100 - @test x0 == x0_backup - - # Fast/Adaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = - ProximalAlgorithms.FastForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = 
fA_autodiff, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 200 - @test x0 == x0_backup - end - - @testset "ZeroFPR" begin - - # ZeroFPR/Nonadaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ZeroFPR(tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - # ZeroFPR/Adaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - - @testset "PANOC" begin - - # PANOC/Nonadaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOC(tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - ## PANOC/Adaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - - @testset "PANOCplus" begin - - # PANOCplus/Nonadaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOCplus(tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - ## PANOCplus/Adaptive - - x0 = ArrayPartition(zeros(T, n1), zeros(T, n2)) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f_autodiff, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - -end diff --git a/test/problems/test_lasso_small_v_split.jl b/test/problems/test_lasso_small_v_split.jl deleted file mode 100644 index 80c8230..0000000 --- a/test/problems/test_lasso_small_v_split.jl +++ /dev/null @@ -1,164 +0,0 @@ -using ProximalOperators: LeastSquares, NormL1, SeparableSum, Sum, Translate -using ProximalAlgorithms -using LinearAlgebra -using AbstractOperators: MatrixOp - -@testset "Lasso small (v. 
split, $T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - A1 = T[ - 1.0 -2.0 3.0 -4.0 5.0 - 2.0 -1.0 0.0 -1.0 3.0 - ] - A2 = T[ - -1.0 0.0 4.0 -3.0 2.0 - -1.0 -1.0 -1.0 1.0 3.0 - ] - A = vcat(MatrixOp(A1), MatrixOp(A2)) - b1 = T[1.0, 2.0] - b2 = T[3.0, 4.0] - - m1, n = size(A1) - m2, _ = size(A2) - m = m1 + m2 - - R = real(T) - - lam = R(0.1) * norm([A1; A2]' * [b1; b2], Inf) - @test typeof(lam) == R - - f = SeparableSum(Translate(SqrNormL2(R(1)), -b1), Translate(SqrNormL2(R(1)), -b2)) - f2 = Sum(LeastSquares(A1, b1), LeastSquares(A2, b2)) - g = NormL1(lam) - - Lf = opnorm([A1; A2])^2 - - x_star = T[-3.877278911564627e-01, 0, 0, 2.174149659863943e-02, 6.168435374149660e-01] - - TOL = R(1e-4) - - @testset "ForwardBackward" begin - - ## Nonfast/Nonadaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 150 - @test x0 == x0_backup - - # Nonfast/Adaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 300 - @test x0 == x0_backup - - # Fast/Nonadaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) - x, it = solver(x0 = x0, f = f2, g = g, Lf = Lf) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 100 - @test x0 == x0_backup - - # Fast/Adaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = - ProximalAlgorithms.FastForwardBackward(tol = TOL, adaptive = true) - x, it = solver(x0 = x0, f = f2, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 200 - @test x0 == x0_backup - end - - @testset "ZeroFPR" begin - - # ZeroFPR/Nonadaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ZeroFPR(tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = opnorm([A1; A2])^2) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - # ZeroFPR/Adaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.ZeroFPR(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - - @testset "PANOC" begin - - # PANOC/Nonadaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOC(tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = opnorm([A1; A2])^2) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - ## PANOC/Adaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOC(adaptive = true, tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - - @testset "PANOCplus" begin - - # PANOCplus/Nonadaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOCplus(tol = TOL) - x, it = solver(x0 = x0, f = f, A = A, g = g, Lf = opnorm([A1; A2])^2) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - ## PANOCplus/Adaptive - - x0 = zeros(T, n) - x0_backup = copy(x0) - solver = ProximalAlgorithms.PANOCplus(adaptive = true, tol = TOL) - x, 
it = solver(x0 = x0, f = f, A = A, g = g) - @test eltype(x) == T - @test norm(x - x_star, Inf) <= TOL - @test it < 20 - @test x0 == x0_backup - - end - -end diff --git a/test/problems/test_sparse_logistic_small.jl b/test/problems/test_sparse_logistic_small.jl index 19b81a7..135fee5 100644 --- a/test/problems/test_sparse_logistic_small.jl +++ b/test/problems/test_sparse_logistic_small.jl @@ -22,7 +22,7 @@ using LinearAlgebra return sum(log.(u)) end - f_autodiff = ProximalAlgorithms.AutoDifferentiable(logistic_loss, ZygoteBackend()) + f_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> logistic_loss(x - b), ZygoteBackend()) fA_autodiff = ProximalAlgorithms.AutoDifferentiable(x -> logistic_loss(A * x - b), ZygoteBackend()) lam = R(0.1) g = NormL1(lam) diff --git a/test/runtests.jl b/test/runtests.jl index cd685b9..00a5e0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,9 +19,6 @@ end Aqua.test_all(ProximalAlgorithms; ambiguities=false) end -include("definitions/arraypartition.jl") -include("definitions/compose.jl") - include("utilities/test_ad.jl") include("utilities/test_iteration_tools.jl") include("utilities/test_fb_tools.jl") @@ -35,8 +32,6 @@ include("problems/test_equivalence.jl") include("problems/test_elasticnet.jl") include("problems/test_lasso_small.jl") include("problems/test_lasso_small_strongly_convex.jl") -include("problems/test_lasso_small_v_split.jl") -include("problems/test_lasso_small_h_split.jl") include("problems/test_linear_programs.jl") include("problems/test_sparse_logistic_small.jl") include("problems/test_nonconvex_qp.jl") From ea5be88d2783f55b6b8d74cd0686ec74fb4caf26 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 13:38:43 +0100 Subject: [PATCH 10/25] unused dependency --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 7241417..648b6b5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From 6b94ce81209699e527db99df83dcd481727d5b56 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 3 Jan 2024 14:04:42 +0100 Subject: [PATCH 11/25] update docs snippets --- docs/Project.toml | 2 ++ docs/src/examples/sparse_linear_regression.jl | 7 +++-- docs/src/guide/custom_objectives.jl | 26 +++++++++++++------ docs/src/guide/getting_started.jl | 10 ++++--- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 708dd2b..f38368c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" @@ -7,6 +8,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "1" diff --git a/docs/src/examples/sparse_linear_regression.jl 
b/docs/src/examples/sparse_linear_regression.jl index 9664e3b..3552a6d 100644 --- a/docs/src/examples/sparse_linear_regression.jl +++ b/docs/src/examples/sparse_linear_regression.jl @@ -51,10 +51,13 @@ end mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2 +using Zygote +using AbstractDifferentiation: ZygoteBackend using ProximalAlgorithms -training_loss = ProximalAlgorithms.ZygoteFunction( - wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)) +training_loss = ProximalAlgorithms.AutoDifferentiable( + wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)), + ZygoteBackend() ) # As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl): diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 56fc3c7..c3cb93f 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -31,10 +31,13 @@ # # Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is +using Zygote +using AbstractDifferentiation: ZygoteBackend using ProximalAlgorithms -rosenbrock2D = ProximalAlgorithms.ZygoteFunction( - x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2 +rosenbrock2D = ProximalAlgorithms.AutoDifferentiable( + x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2, + ZygoteBackend() ) # To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping: @@ -82,17 +85,23 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co mutable struct Counting{T} f::T + eval_count::Int gradient_count::Int prox_count::Int end -Counting(f::T) where T = Counting{T}(f, 0, 0) +Counting(f::T) where T = Counting{T}(f, 0, 0, 0) # Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there: -function ProximalCore.gradient!(y, f::Counting, x) - f.gradient_count += 1 - return ProximalCore.gradient!(y, f.f, x) +function ProximalAlgorithms.value_and_pullback_function(f::Counting, x) + f.eval_count += 1 + fx, pb = ProximalAlgorithms.value_and_pullback_function(f.f, x) + function counting_pullback(v) + f.gradient_count += 1 + return pb(v) + end + return fx, counting_pullback end function ProximalCore.prox!(y, f::Counting, x, gamma) @@ -109,5 +118,6 @@ solution, iterations = panoc(x0=-ones(2), f=f, g=g) # and check how many operations where actually performed: -println(f.gradient_count) -println(g.prox_count) +println("function evals: $(f.eval_count)") +println("gradient evals: $(f.gradient_count)") +println(" prox evals: $(g.prox_count)") diff --git a/docs/src/guide/getting_started.jl b/docs/src/guide/getting_started.jl index 228df29..8d9001f 100644 --- a/docs/src/guide/getting_started.jl +++ b/docs/src/guide/getting_started.jl @@ -51,11 +51,14 @@ # which we will solve using the fast proximal gradient method (also known as fast forward-backward splitting): using LinearAlgebra +using Zygote +using AbstractDifferentiation: ZygoteBackend using ProximalOperators using ProximalAlgorithms -quadratic_cost = ProximalAlgorithms.ZygoteFunction( - x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x) +quadratic_cost = ProximalAlgorithms.AutoDifferentiable( + x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x), + ZygoteBackend() ) box_indicator = ProximalOperators.IndBox(0, 1) @@ -70,7 +73,8 @@ solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator) 
# We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards: --ProximalAlgorithms.gradient(quadratic_cost, solution)[1] +v, pullback = ProximalAlgorithms.value_and_pullback_function(quadratic_cost, solution) +-pullback(one(v))[1] # Or by plotting the solution against the cost function and constraint: From 59686f40b863eec066621b1b1ca5084f127ae2df Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 5 Jan 2024 23:41:01 +0100 Subject: [PATCH 12/25] updates --- docs/src/guide/custom_objectives.jl | 13 ++++++------- docs/src/guide/getting_started.jl | 4 ++-- src/ProximalAlgorithms.jl | 18 ++++++++++++++++-- src/algorithms/davis_yin.jl | 8 ++++---- src/algorithms/fast_forward_backward.jl | 8 ++++---- src/algorithms/forward_backward.jl | 8 ++++---- src/algorithms/li_lin.jl | 12 ++++++------ src/algorithms/panoc.jl | 16 ++++++++-------- src/algorithms/panocplus.jl | 16 ++++++++-------- src/algorithms/primal_dual.jl | 8 ++++---- src/algorithms/sfista.jl | 8 ++++---- src/algorithms/zerofpr.jl | 12 ++++++------ src/utilities/fb_tools.jl | 25 +++++++++++-------------- test/Project.toml | 1 + test/problems/test_elasticnet.jl | 2 +- test/runtests.jl | 4 ++-- test/utilities/test_ad.jl | 11 +++++------ 17 files changed, 92 insertions(+), 82 deletions(-) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index c3cb93f..338fdc0 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -18,13 +18,12 @@ # and everything will work out of the box. # # If however one would like to provide their own gradient implementation (e.g. for efficiency reasons), -# they can simply implement a method for [`ProximalCore.gradient!`](@ref). +# they can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref). # # ```@docs # ProximalCore.prox # ProximalCore.prox! -# ProximalCore.gradient -# ProximalCore.gradient! 
+# ProximalAlgorithms.value_and_pullback # ``` # # ## Example: constrained Rosenbrock @@ -94,12 +93,12 @@ Counting(f::T) where T = Counting{T}(f, 0, 0, 0) # Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there: -function ProximalAlgorithms.value_and_pullback_function(f::Counting, x) +function ProximalAlgorithms.value_and_pullback(f::Counting, x) f.eval_count += 1 - fx, pb = ProximalAlgorithms.value_and_pullback_function(f.f, x) - function counting_pullback(v) + fx, pb = ProximalAlgorithms.value_and_pullback(f.f, x) + function counting_pullback() f.gradient_count += 1 - return pb(v) + return pb() end return fx, counting_pullback end diff --git a/docs/src/guide/getting_started.jl b/docs/src/guide/getting_started.jl index 8d9001f..bb8be75 100644 --- a/docs/src/guide/getting_started.jl +++ b/docs/src/guide/getting_started.jl @@ -73,8 +73,8 @@ solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator) # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards: -v, pullback = ProximalAlgorithms.value_and_pullback_function(quadratic_cost, solution) --pullback(one(v))[1] +v, pb = ProximalAlgorithms.value_and_pullback(quadratic_cost, solution) +-pb() # Or by plotting the solution against the cost function and constraint: diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index d044344..c59fec0 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -14,9 +14,23 @@ end (f::AutoDifferentiable)(x) = f.f(x) -value_and_pullback_function(f::AutoDifferentiable, x) = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x) +""" + value_and_pullback(f, x) + +Return a tuple containing the value of `f` at `x`, and the pullback function `pb`. + +Function `pb`, once called, yields the gradient of `f` at `x`. 
+""" +value_and_pullback -value_and_pullback_function(f::ProximalCore.Zero, x) = f(x), _ -> (zero(x)) +function value_and_pullback(f::AutoDifferentiable, x) + fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x) + return fx, () -> pb(one(fx))[1] +end + +function value_and_pullback(f::ProximalCore.Zero, x) + f(x), () -> zero(x) +end # various utilities diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index c6fc18e..8dbfad2 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -55,8 +55,8 @@ end function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) - f_xg, pb = value_and_pullback_function(iter.f, xg) - grad_f_xg = pb(one(f_xg))[1] + f_xg, pb = value_and_pullback(iter.f, xg) + grad_f_xg = pb() z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) res = xh - xg @@ -67,8 +67,8 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) - f_xg, pb = value_and_pullback_function(iter.f, state.xg) - state.grad_f_xg .= pb(one(f_xg))[1] + f_xg, pb = value_and_pullback(iter.f, state.xg) + state.grad_f_xg .= pb() state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) state.res .= state.xh .- state.xg diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index a6b4d7b..8413892 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -68,8 +68,8 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback_function(iter.f, x) - grad_f_x = pb(one(f_x))[1] + f_x, pb = value_and_pullback(iter.f, x) + grad_f_x = pb() gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -104,8 +104,8 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.x .= state.z .+ beta .* (state.z .- state.z_prev) state.z_prev, state.z = state.z, state.z_prev - state.f_x, pb = value_and_pullback_function(iter.f, state.x) - state.grad_f_x .= pb(one(state.f_x))[1] + state.f_x, pb = value_and_pullback(iter.f, state.x) + state.grad_f_x .= pb() state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) state.res .= state.x .- state.z diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index d321251..9ca7050 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -59,8 +59,8 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback_function(iter.f, x) - grad_f_x = pb(one(f_x))[1] + f_x, pb = value_and_pullback(iter.f, x) + grad_f_x = pb() gamma = iter.gamma === nothing ? 
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -82,8 +82,8 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x else state.x, state.z = state.z, state.x - state.f_x, pb = value_and_pullback_function(iter.f, state.x) - state.grad_f_x .= pb(one(state.f_x))[1] + state.f_x, pb = value_and_pullback(iter.f, state.x) + state.grad_f_x .= pb() end state.y .= state.x .- state.gamma .* state.grad_f_x diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index b147bbc..59c510c 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -62,8 +62,8 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) - f_y, pb = value_and_pullback_function(iter.f, y) - grad_f_y = pb(one(f_y))[1] + f_y, pb = value_and_pullback(iter.f, y) + grad_f_y = pb() # TODO: initialize gamma if not provided # TODO: authors suggest Barzilai-Borwein rule? @@ -103,8 +103,8 @@ function Base.iterate( else # TODO: re-use available space in state? # TODO: backtrack gamma at x - f_x, pb = value_and_pullback_function(iter.f, x) - grad_f_x = pb(one(f_x))[1] + f_x, pb = value_and_pullback(iter.f, x) + grad_f_x = pb() x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) Fv = iter.f(v) + g_v @@ -123,8 +123,8 @@ function Base.iterate( Fx = Fv end - state.f_y, pb = value_and_pullback_function(iter.f, state.y) - state.grad_f_y .= pb(one(state.f_y))[1] + state.f_y, pb = value_and_pullback(iter.f, state.y) + state.grad_f_y .= pb() state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index dfa20d0..d9fe884 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -86,8 +86,8 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(iter.f, Ax) - grad_f_Ax = pb(one(f_Ax))[1] + f_Ax, pb = value_and_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -153,8 +153,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad - state.f_Ax_d, pb = value_and_pullback_function(iter.f, state.Ax_d) - state.grad_f_Ax_d .= pb(one(state.f_Ax_d))[1] + state.f_Ax_d, pb = value_and_pullback(iter.f, state.Ax_d) + state.grad_f_Ax_d .= pb() mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) copyto!(state.x, state.x_d) @@ -191,8 +191,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # along a line using interpolation and linear combinations # this allows saving operations if isinf(f_Az) - f_Az, pb = value_and_pullback_function(iter.f, state.Az) - state.grad_f_Az .= pb(one(f_Az))[1] + f_Az, pb = value_and_pullback(iter.f, state.Az) + state.grad_f_Az .= pb() end if isinf(c) mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az) @@ -206,8 +206,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where else # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step - state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax))[1] + state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index 26633da..8b4e757 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -79,8 +79,8 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = f_model(state.f_Ax, s function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(iter.f, Ax) - grad_f_Ax = pb(one(f_Ax))[1] + f_Ax, pb = value_and_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -98,8 +98,8 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} ) else mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback_function(iter.f, state.Az) - state.grad_f_Az = pb(one(f_Az))[1] + f_Az, pb = value_and_pullback(iter.f, state.Az) + state.grad_f_Az = pb() end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) return state, state @@ -154,8 +154,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where end mul!(state.Ax, iter.A, state.x) - state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax))[1] + state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax @@ -165,8 +165,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where f_Az_upp = f_model(iter, state) mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback_function(iter.f, state.Az) - state.grad_f_Az .= pb(one(f_Az))[1] + f_Az, pb = value_and_pullback(iter.f, state.Az) + state.grad_f_Az .= pb() if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index 3e83a90..cd7ccdb 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -167,8 +167,8 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step - f_x, pb = value_and_pullback_function(iter.f, state.x) - state.gradf .= pb(one(f_x))[1] + f_x, pb = value_and_pullback(iter.f, state.x) + state.gradf .= pb() mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf state.temp_x .*= -iter.gamma[1] @@ -176,8 +176,8 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - lc_y, pb = value_and_pullback_function(convex_conjugate(iter.l), state.y) - state.gradl .= pb(one(lc_y))[1] + lc_y, pb = value_and_pullback(convex_conjugate(iter.l), state.y) + state.gradl .= pb() state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 2967b1f..608127e 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -71,8 +71,8 @@ function Base.iterate( state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev - f_xt, pb = value_and_pullback_function(iter.f, state.xt) - state.gradf_xt .= pb(one(f_xt))[1] + f_xt, pb = value_and_pullback(iter.f, state.xt) + state.gradf_xt .= pb() λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. prox!(state.y, iter.g, state.xt - λ2 * state.gradf_xt, λ2) @@ -94,8 +94,8 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty else # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) - f_y, pb = value_and_pullback_function(iter.f, state.y) - gradf_y = pb(one(f_y))[1] + f_y, pb = value_and_pullback(iter.f, state.y) + gradf_y = pb() r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index da247a1..0f53dd2 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -84,8 +84,8 @@ f_model(iter::ZeroFPRIteration, state::ZeroFPRState) = f_model(state.f_Ax, state function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback_function(iter.f, Ax) - grad_f_Ax = pb(one(f_Ax))[1] + f_Ax, pb = value_and_pullback(iter.f, Ax) + grad_f_Ax = pb() gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -131,8 +131,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R f_Axbar_upp, f_Axbar else mul!(state.Axbar, iter.A, state.xbar) - f_Axbar, pb = value_and_pullback_function(iter.f, state.Axbar) - state.grad_f_Axbar .= pb(one(f_Axbar))[1] + f_Axbar, pb = value_and_pullback(iter.f, state.Axbar) + state.grad_f_Axbar .= pb() f_model(iter, state), f_Axbar end @@ -167,8 +167,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.x .= state.xbar_prev .+ state.tau .* state.d state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic - state.f_Ax, pb = value_and_pullback_function(iter.f, state.Ax) - state.grad_f_Ax .= pb(one(state.f_Ax))[1] + state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) + state.grad_f_Ax .= pb() mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax state.g_xbar = prox!(state.xbar, iter.g, state.y, state.gamma) diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index 8d31506..a2c84f7 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -7,16 +7,16 @@ end function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 - f_Axeps, pb = value_and_pullback_function(f, A * xeps) - grad_f_Axeps = pb(one(R))[1] + f_Axeps, pb = value_and_pullback(f, A * xeps) + grad_f_Axeps = pb() return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end function lower_bound_smoothness_constant(f, A, x) R = real(eltype(x)) Ax = A * x - f_Ax, pb = value_and_pullback_function(f, Ax) - grad_f_Ax = pb(one(R))[1] + f_Ax, pb = value_and_pullback(f, Ax) + grad_f_Ax = pb() return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end @@ -29,10 +29,7 @@ function backtrack_stepsize!( ) where R f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback_function(f, Az) - if grad_f_Az !== nothing - grad_f_Az .= pb(one(f_Az))[1] - end + f_Az, pb = value_and_pullback(f, Az) tol = 10 * eps(R) * (1 + abs(f_Az)) while f_Az > f_Az_upp + tol && gamma >= minimum_gamma gamma /= 2 @@ -41,12 +38,12 @@ function backtrack_stepsize!( res .= x .- z f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback_function(f, Az) - if grad_f_Az !== nothing - grad_f_Az .= pb(one(f_Az))[1] - end + f_Az, pb = value_and_pullback(f, Az) tol = 10 * eps(R) * (1 + abs(f_Az)) end + if grad_f_Az !== nothing + grad_f_Az .= pb() + end if gamma < minimum_gamma @warn 
"stepsize `gamma` became too small ($(gamma))" end @@ -57,8 +54,8 @@ function backtrack_stepsize!( gamma, f, A, g, x; alpha = 1, minimum_gamma = 1e-7 ) Ax = A * x - f_Ax, pb = value_and_pullback_function(f, Ax) - grad_f_Ax = pb(one(f_Ax))[1] + f_Ax, pb = value_and_pullback(f, Ax) + grad_f_Ax = pb() At_grad_f_Ax = A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax z, g_z = prox(g, y, gamma) diff --git a/test/Project.toml b/test/Project.toml index 648b6b5..7241417 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/problems/test_elasticnet.jl b/test/problems/test_elasticnet.jl index 29e4052..ffa376a 100644 --- a/test/problems/test_elasticnet.jl +++ b/test/problems/test_elasticnet.jl @@ -54,7 +54,7 @@ using AbstractDifferentiation: ZygoteBackend afba_test_params = [ (2, 0, 130), - (1, 1, 1890), + (1, 1, 2000), (0, 1, 320), (0, 0, 194), (1, 0, 130), diff --git a/test/runtests.jl b/test/runtests.jl index 00a5e0a..eadead5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,9 +10,9 @@ end (f::Quadratic)(x) = dot(x, f.Q * x) / 2 + dot(f.q, x) -function ProximalAlgorithms.value_and_pullback_function(f::Quadratic, x) +function ProximalAlgorithms.value_and_pullback(f::Quadratic, x) grad = f.Q * x + f.q - return dot(grad, x) / 2 + dot(f.q, x), v -> (grad,) + return dot(grad, x) / 2 + dot(f.q, x), () -> grad end @testset "Aqua" begin diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index a642542..c28f097 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -4,14 +4,13 @@ using ProximalOperators: NormL1 using ProximalAlgorithms using Zygote using ReverseDiff -using ForwardDiff -using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBackend +using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend @testset "Autodiff backend ($B on $T)" for (T, B) in Iterators.product( [Float32, Float64, ComplexF32, ComplexF64], - [ZygoteBackend, ReverseDiffBackend, ForwardDiffBackend], + [ZygoteBackend, ReverseDiffBackend], ) - if T <: Complex && B in [ReverseDiffBackend, ForwardDiffBackend] + if T <: Complex && B == ReverseDiffBackend continue end @@ -29,8 +28,8 @@ using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend, ForwardDiffBac x0 = zeros(T, n) - f_x0, pb = ProximalAlgorithms.value_and_pullback_function(f, x0) - grad_f_x0 = @inferred pb(one(R))[1] + f_x0, pb = ProximalAlgorithms.value_and_pullback(f, x0) + grad_f_x0 = @inferred pb() lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R From aff22b0c24ca55dcc40ebc61b22319ca1e387041 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 13:13:26 +0100 Subject: [PATCH 13/25] update docs --- docs/src/guide/getting_started.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/guide/getting_started.jl b/docs/src/guide/getting_started.jl index bb8be75..53eb239 100644 --- a/docs/src/guide/getting_started.jl +++ b/docs/src/guide/getting_started.jl @@ -20,7 +20,7 @@ # The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite). 
# # To evaluate these first-order primitives, in ProximalAlgorithms: -# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [Zygote](https://github.com/FluxML/Zygote.jl)). +# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends). # * ``\operatorname{prox}_{f_i}`` relies on the intereface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15). # Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms). # From 7cdcc2e338d8fc5cf3142fdc540b973eb094e3f1 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 13:20:42 +0100 Subject: [PATCH 14/25] Update docs --- docs/src/guide/custom_objectives.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 338fdc0..1fdb9d6 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -12,13 +12,13 @@ # # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref). # -# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref), -# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation. -# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken, -# and everything will work out of the box. +# To compute gradients, algorithms use [`ProximalAlgorithm.value_and_pullback`](@ref): +# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation +# with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref), +# as the esamples below show. # -# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons), -# they can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref). +# If however you would like to provide your own gradient implementation (e.g. for efficiency reasons), +# you can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref) on your own function type. # # ```@docs # ProximalCore.prox @@ -91,7 +91,7 @@ end Counting(f::T) where T = Counting{T}(f, 0, 0, 0) -# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there: +# Now we only need to intercept any call to `value_and_pullback` and `prox!` and increase counters there: function ProximalAlgorithms.value_and_pullback(f::Counting, x) f.eval_count += 1 From b4020241d5f5bdf292b8faf900b32a2f607c736d Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 13:50:27 +0100 Subject: [PATCH 15/25] Fixup --- docs/src/guide/custom_objectives.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 1fdb9d6..6c460ff 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -12,7 +12,7 @@ # # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref). 
# -# To compute gradients, algorithms use [`ProximalAlgorithm.value_and_pullback`](@ref): +# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_pullback`](@ref): # this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation # with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref), # as the esamples below show. From e79ec2bcc96b2fd58ffb980ecf9e03a0421b9ece Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 14:11:08 +0100 Subject: [PATCH 16/25] Update ProximalAlgorithms.jl --- src/ProximalAlgorithms.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index c59fec0..d717b63 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -7,6 +7,13 @@ using ProximalCore: prox, prox! const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} +""" + Autodifferentiable(f, backend) + +Construct a function from `f` to be auto-differentiated via `backend`. + +The backend can be any from AbstractDifferentiation.jl. +""" struct AutoDifferentiable{F, B} f::F backend::B From 3f93199567a5593a7079e21d0d0b8e784772fe64 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 16:51:28 +0100 Subject: [PATCH 17/25] Fix --- src/ProximalAlgorithms.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index d717b63..471a27d 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -8,11 +8,11 @@ const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} """ - Autodifferentiable(f, backend) + AutoDifferentiable(f, backend) -Construct a function from `f` to be auto-differentiated via `backend`. +Wrap function `f` to be auto-differentiated using `backend`. -The backend can be any from AbstractDifferentiation.jl. +The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl). """ struct AutoDifferentiable{F, B} f::F From 4bf6f754ac395f8f8ce7f7e45f456508e4876a90 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 18:31:45 +0100 Subject: [PATCH 18/25] fix --- docs/src/guide/custom_objectives.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 6c460ff..8cda6cb 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -24,6 +24,7 @@ # ProximalCore.prox # ProximalCore.prox! # ProximalAlgorithms.value_and_pullback +# ProximalAlgorithms.AutoDifferentiable # ``` # # ## Example: constrained Rosenbrock From d4076603a59c11265f2580e44e1b3913455d080b Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 6 Jan 2024 18:34:30 +0100 Subject: [PATCH 19/25] add justfile --- justfile | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 justfile diff --git a/justfile b/justfile new file mode 100644 index 0000000..81ccb7f --- /dev/null +++ b/justfile @@ -0,0 +1,14 @@ +julia: + julia --project=. + +instantiate: + julia --project=. -e 'using Pkg; Pkg.instantiate()' + +test: + julia --project=. -e 'using Pkg; Pkg.test()' + +format: + julia --project=. 
-e 'using JuliaFormatter: format; format(".")' + +docs: + julia --project=./docs docs/make.jl From d046528c6ea0a0c5c8ab56a77c540958ac269008 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 17 Jan 2024 17:25:29 +0100 Subject: [PATCH 20/25] Update docs/src/guide/custom_objectives.jl Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- docs/src/guide/custom_objectives.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 8cda6cb..8146f73 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -15,7 +15,7 @@ # To compute gradients, algorithms use [`ProximalAlgorithms.value_and_pullback`](@ref): # this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation # with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref), -# as the esamples below show. +# as the examples below show. # # If however you would like to provide your own gradient implementation (e.g. for efficiency reasons), # you can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref) on your own function type. From 2e062c340b136bfe934e375f271ed3d8cc6c7563 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 20 Jan 2024 16:59:31 +0100 Subject: [PATCH 21/25] address comments --- docs/src/guide/custom_objectives.jl | 12 ++++++------ docs/src/guide/getting_started.jl | 5 +++-- src/ProximalAlgorithms.jl | 16 +++++++++------- src/algorithms/davis_yin.jl | 8 ++++---- src/algorithms/fast_forward_backward.jl | 8 ++++---- src/algorithms/forward_backward.jl | 8 ++++---- src/algorithms/li_lin.jl | 12 ++++++------ src/algorithms/panoc.jl | 16 ++++++++-------- src/algorithms/panocplus.jl | 16 ++++++++-------- src/algorithms/primal_dual.jl | 8 ++++---- src/algorithms/sfista.jl | 8 ++++---- src/algorithms/zerofpr.jl | 12 ++++++------ src/utilities/fb_tools.jl | 18 +++++++++--------- test/runtests.jl | 2 +- test/utilities/test_ad.jl | 4 ++-- 15 files changed, 78 insertions(+), 75 deletions(-) diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 8146f73..5189ddb 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -12,18 +12,18 @@ # # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref). # -# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_pullback`](@ref): +# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_gradient_closure`](@ref): # this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation # with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref), # as the examples below show. # # If however you would like to provide your own gradient implementation (e.g. for efficiency reasons), -# you can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref) on your own function type. +# you can simply implement a method for [`ProximalAlgorithms.value_and_gradient_closure`](@ref) on your own function type. # # ```@docs # ProximalCore.prox # ProximalCore.prox! 
-# ProximalAlgorithms.value_and_pullback +# ProximalAlgorithms.value_and_gradient_closure # ProximalAlgorithms.AutoDifferentiable # ``` # @@ -92,11 +92,11 @@ end Counting(f::T) where T = Counting{T}(f, 0, 0, 0) -# Now we only need to intercept any call to `value_and_pullback` and `prox!` and increase counters there: +# Now we only need to intercept any call to `value_and_gradient_closure` and `prox!` and increase counters there: -function ProximalAlgorithms.value_and_pullback(f::Counting, x) +function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x) f.eval_count += 1 - fx, pb = ProximalAlgorithms.value_and_pullback(f.f, x) + fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x) function counting_pullback() f.gradient_count += 1 return pb() diff --git a/docs/src/guide/getting_started.jl b/docs/src/guide/getting_started.jl index 53eb239..1333bf4 100644 --- a/docs/src/guide/getting_started.jl +++ b/docs/src/guide/getting_started.jl @@ -72,9 +72,10 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit=1000, tol=1e-5, verbose=true) solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator) # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards: +# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref). -v, pb = ProximalAlgorithms.value_and_pullback(quadratic_cost, solution) --pb() +v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution) +-cl() # Or by plotting the solution against the cost function and constraint: diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 471a27d..2494da6 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -10,8 +10,10 @@ const Maybe{T} = Union{T,Nothing} """ AutoDifferentiable(f, backend) -Wrap function `f` to be auto-differentiated using `backend`. +Callable struct wrapping function `f` to be auto-differentiated using `backend`. +When called, it evaluates the same as `f`, while [`ProximalAlgorithms.value_and_gradient_closure`](@ref) +is implemented using `backend` for automatic differentiation. The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl). """ struct AutoDifferentiable{F, B} @@ -22,20 +24,20 @@ end (f::AutoDifferentiable)(x) = f.f(x) """ - value_and_pullback(f, x) + value_and_gradient_closure(f, x) -Return a tuple containing the value of `f` at `x`, and the pullback function `pb`. +Return a tuple containing the value of `f` at `x`, and a closure `cl`. -Function `pb`, once called, yields the gradient of `f` at `x`. +Function `cl`, once called, yields the gradient of `f` at `x`. 
""" -value_and_pullback +value_and_gradient_closure -function value_and_pullback(f::AutoDifferentiable, x) +function value_and_gradient_closure(f::AutoDifferentiable, x) fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x) return fx, () -> pb(one(fx))[1] end -function value_and_pullback(f::ProximalCore.Zero, x) +function value_and_gradient_closure(f::ProximalCore.Zero, x) f(x), () -> zero(x) end diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index 8dbfad2..01b362b 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -55,8 +55,8 @@ end function Base.iterate(iter::DavisYinIteration) z = copy(iter.x0) xg, = prox(iter.g, z, iter.gamma) - f_xg, pb = value_and_pullback(iter.f, xg) - grad_f_xg = pb() + f_xg, cl = value_and_gradient_closure(iter.f, xg) + grad_f_xg = cl() z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg xh, = prox(iter.h, z_half, iter.gamma) res = xh - xg @@ -67,8 +67,8 @@ end function Base.iterate(iter::DavisYinIteration, state::DavisYinState) prox!(state.xg, iter.g, state.z, iter.gamma) - f_xg, pb = value_and_pullback(iter.f, state.xg) - state.grad_f_xg .= pb() + f_xg, cl = value_and_gradient_closure(iter.f, state.xg) + state.grad_f_xg .= cl() state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg prox!(state.xh, iter.h, state.z_half, iter.gamma) state.res .= state.xh .- state.xg diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index 8413892..916511c 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -68,8 +68,8 @@ end function Base.iterate(iter::FastForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback(iter.f, x) - grad_f_x = pb() + f_x, cl = value_and_gradient_closure(iter.f, x) + grad_f_x = cl() gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -104,8 +104,8 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB state.x .= state.z .+ beta .* (state.z .- state.z_prev) state.z_prev, state.z = state.z, state.z_prev - state.f_x, pb = value_and_pullback(iter.f, state.x) - state.grad_f_x .= pb() + state.f_x, cl = value_and_gradient_closure(iter.f, state.x) + state.grad_f_x .= cl() state.y .= state.x .- state.gamma .* state.grad_f_x state.g_z = prox!(state.z, iter.g, state.y, state.gamma) state.res .= state.x .- state.z diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index 9ca7050..15b6040 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -59,8 +59,8 @@ end function Base.iterate(iter::ForwardBackwardIteration) x = copy(iter.x0) - f_x, pb = value_and_pullback(iter.f, x) - grad_f_x = pb() + f_x, cl = value_and_gradient_closure(iter.f, x) + grad_f_x = cl() gamma = iter.gamma === nothing ? 
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma y = x - gamma .* grad_f_x z, g_z = prox(iter.g, y, gamma) @@ -82,8 +82,8 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x else state.x, state.z = state.z, state.x - state.f_x, pb = value_and_pullback(iter.f, state.x) - state.grad_f_x .= pb() + state.f_x, cl = value_and_gradient_closure(iter.f, state.x) + state.grad_f_x .= cl() end state.y .= state.x .- state.gamma .* state.grad_f_x diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index 59c510c..fc7e367 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -62,8 +62,8 @@ end function Base.iterate(iter::LiLinIteration{R}) where {R} y = copy(iter.x0) - f_y, pb = value_and_pullback(iter.f, y) - grad_f_y = pb() + f_y, cl = value_and_gradient_closure(iter.f, y) + grad_f_y = cl() # TODO: initialize gamma if not provided # TODO: authors suggest Barzilai-Borwein rule? @@ -103,8 +103,8 @@ function Base.iterate( else # TODO: re-use available space in state? # TODO: backtrack gamma at x - f_x, pb = value_and_pullback(iter.f, x) - grad_f_x = pb() + f_x, cl = value_and_gradient_closure(iter.f, x) + grad_f_x = cl() x_forward = state.x - state.gamma .* grad_f_x v, g_v = prox(iter.g, x_forward, state.gamma) Fv = iter.f(v) + g_v @@ -123,8 +123,8 @@ function Base.iterate( Fx = Fv end - state.f_y, pb = value_and_pullback(iter.f, state.y) - state.grad_f_y .= pb() + state.f_y, cl = value_and_gradient_closure(iter.f, state.y) + state.grad_f_y .= cl() state.y_forward .= state.y .- state.gamma .* state.grad_f_y state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index d9fe884..72414ff 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -86,8 +86,8 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_ function Base.iterate(iter::PANOCIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, cl = value_and_gradient_closure(iter.f, Ax) + grad_f_Ax = cl() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -153,8 +153,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where state.x_d .= state.x .+ state.d state.Ax_d .= state.Ax .+ state.Ad - state.f_Ax_d, pb = value_and_pullback(iter.f, state.Ax_d) - state.grad_f_Ax_d .= pb() + state.f_Ax_d, cl = value_and_gradient_closure(iter.f, state.Ax_d) + state.grad_f_Ax_d .= cl() mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d) copyto!(state.x, state.x_d) @@ -191,8 +191,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where # along a line using interpolation and linear combinations # this allows saving operations if isinf(f_Az) - f_Az, pb = value_and_pullback(iter.f, state.Az) - state.grad_f_Az .= pb() + f_Az, cl = value_and_gradient_closure(iter.f, state.Az) + state.grad_f_Az .= cl() end if isinf(c) mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az) @@ -206,8 +206,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where else # otherwise, in the general case where f is only smooth, we compute # one gradient and matvec per backtracking step - state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax) + state.grad_f_Ax .= cl() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) end diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index 8b4e757..2b33e4d 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -79,8 +79,8 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = f_model(state.f_Ax, s function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, cl = value_and_gradient_closure(iter.f, Ax) + grad_f_Ax = cl() gamma = iter.gamma === nothing ? 
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -98,8 +98,8 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R} ) else mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback(iter.f, state.Az) - state.grad_f_Az = pb() + f_Az, cl = value_and_gradient_closure(iter.f, state.Az) + state.grad_f_Az = cl() end mul!(state.At_grad_f_Az, adjoint(iter.A), state.grad_f_Az) return state, state @@ -154,8 +154,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where end mul!(state.Ax, iter.A, state.x) - state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax) + state.grad_f_Ax .= cl() mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax @@ -165,8 +165,8 @@ function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where f_Az_upp = f_model(iter, state) mul!(state.Az, iter.A, state.z) - f_Az, pb = value_and_pullback(iter.f, state.Az) - state.grad_f_Az .= pb() + f_Az, cl = value_and_gradient_closure(iter.f, state.Az) + state.grad_f_Az .= cl() if (iter.gamma === nothing || iter.adaptive == true) tol = 10 * eps(R) * (1 + abs(f_Az)) if f_Az > f_Az_upp + tol && state.gamma >= iter.minimum_gamma diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index cd7ccdb..d9e9074 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -167,8 +167,8 @@ end function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(iter.x0), y=copy(iter.y0))) # perform xbar-update step - f_x, pb = value_and_pullback(iter.f, state.x) - state.gradf .= pb() + f_x, cl = value_and_gradient_closure(iter.f, state.x) + state.gradf .= cl() mul!(state.temp_x, iter.L', state.y) state.temp_x .+= state.gradf state.temp_x .*= -iter.gamma[1] @@ -176,8 +176,8 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - lc_y, pb = value_and_pullback(convex_conjugate(iter.l), state.y) - state.gradl .= pb() + lc_y, cl = value_and_gradient_closure(convex_conjugate(iter.l), state.y) + state.gradl .= cl() state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 608127e..c66e1a8 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -71,8 +71,8 @@ function Base.iterate( state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 state.A = state.APrev + state.a state.xt .= (state.APrev / state.A) .* state.yPrev + (state.a / state.A) .* state.xPrev - f_xt, pb = value_and_pullback(iter.f, state.xt) - state.gradf_xt .= pb() + f_xt, cl = value_and_gradient_closure(iter.f, state.xt) + state.gradf_xt .= cl() λ2 = state.λ / (1 + state.λ * iter.mf) # FISTA acceleration steps. prox!(state.y, iter.g, state.xt - λ2 * state.gradf_xt, λ2) @@ -94,8 +94,8 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty else # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). 
λ2 = state.λ / (1 + state.λ * iter.mf) - f_y, pb = value_and_pullback(iter.f, state.y) - gradf_y = pb() + f_y, cl = value_and_gradient_closure(iter.f, state.y) + gradf_y = cl() r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 0f53dd2..6008cc4 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -84,8 +84,8 @@ f_model(iter::ZeroFPRIteration, state::ZeroFPRState) = f_model(state.f_Ax, state function Base.iterate(iter::ZeroFPRIteration{R}) where R x = copy(iter.x0) Ax = iter.A * x - f_Ax, pb = value_and_pullback(iter.f, Ax) - grad_f_Ax = pb() + f_Ax, cl = value_and_gradient_closure(iter.f, Ax) + grad_f_Ax = cl() gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma At_grad_f_Ax = iter.A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax @@ -131,8 +131,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R f_Axbar_upp, f_Axbar else mul!(state.Axbar, iter.A, state.xbar) - f_Axbar, pb = value_and_pullback(iter.f, state.Axbar) - state.grad_f_Axbar .= pb() + f_Axbar, cl = value_and_gradient_closure(iter.f, state.Axbar) + state.grad_f_Axbar .= cl() f_model(iter, state), f_Axbar end @@ -167,8 +167,8 @@ function Base.iterate(iter::ZeroFPRIteration{R}, state::ZeroFPRState) where R state.x .= state.xbar_prev .+ state.tau .* state.d state.Ax .= state.Axbar .+ state.tau .* state.Ad # TODO: can precompute most of next line in case f is quadratic - state.f_Ax, pb = value_and_pullback(iter.f, state.Ax) - state.grad_f_Ax .= pb() + state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax) + state.grad_f_Ax .= cl() mul!(state.At_grad_f_Ax, iter.A', state.grad_f_Ax) state.y .= state.x .- state.gamma .* state.At_grad_f_Ax state.g_xbar = prox!(state.xbar, iter.g, state.y, state.gamma) diff --git a/src/utilities/fb_tools.jl b/src/utilities/fb_tools.jl index a2c84f7..e2acff4 100644 --- a/src/utilities/fb_tools.jl +++ b/src/utilities/fb_tools.jl @@ -7,16 +7,16 @@ end function lower_bound_smoothness_constant(f, A, x, grad_f_Ax) R = real(eltype(x)) xeps = x .+ 1 - f_Axeps, pb = value_and_pullback(f, A * xeps) - grad_f_Axeps = pb() + f_Axeps, cl = value_and_gradient_closure(f, A * xeps) + grad_f_Axeps = cl() return norm(A' * (grad_f_Axeps - grad_f_Ax)) / R(sqrt(length(x))) end function lower_bound_smoothness_constant(f, A, x) R = real(eltype(x)) Ax = A * x - f_Ax, pb = value_and_pullback(f, Ax) - grad_f_Ax = pb() + f_Ax, cl = value_and_gradient_closure(f, Ax) + grad_f_Ax = cl() return lower_bound_smoothness_constant(f, A, x, grad_f_Ax) end @@ -29,7 +29,7 @@ function backtrack_stepsize!( ) where R f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback(f, Az) + f_Az, cl = value_and_gradient_closure(f, Az) tol = 10 * eps(R) * (1 + abs(f_Az)) while f_Az > f_Az_upp + tol && gamma >= minimum_gamma gamma /= 2 @@ -38,11 +38,11 @@ function backtrack_stepsize!( res .= x .- z f_Az_upp = f_model(f_Ax, At_grad_f_Ax, res, alpha / gamma) _mul!(Az, A, z) - f_Az, pb = value_and_pullback(f, Az) + f_Az, cl = value_and_gradient_closure(f, Az) tol = 10 * eps(R) * (1 + abs(f_Az)) end if grad_f_Az !== nothing - grad_f_Az .= pb() + grad_f_Az .= cl() end if gamma < minimum_gamma @warn "stepsize `gamma` became too small ($(gamma))" @@ -54,8 +54,8 @@ function backtrack_stepsize!( gamma, f, A, g, x; alpha = 1, minimum_gamma = 1e-7 ) Ax = A * x - f_Ax, pb = value_and_pullback(f, Ax) - 
grad_f_Ax = pb() + f_Ax, cl = value_and_gradient_closure(f, Ax) + grad_f_Ax = cl() At_grad_f_Ax = A' * grad_f_Ax y = x - gamma .* At_grad_f_Ax z, g_z = prox(g, y, gamma) diff --git a/test/runtests.jl b/test/runtests.jl index eadead5..8380741 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ end (f::Quadratic)(x) = dot(x, f.Q * x) / 2 + dot(f.q, x) -function ProximalAlgorithms.value_and_pullback(f::Quadratic, x) +function ProximalAlgorithms.value_and_gradient_closure(f::Quadratic, x) grad = f.Q * x + f.q return dot(grad, x) / 2 + dot(f.q, x), () -> grad end diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index c28f097..25c15b3 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -28,8 +28,8 @@ using AbstractDifferentiation: ZygoteBackend, ReverseDiffBackend x0 = zeros(T, n) - f_x0, pb = ProximalAlgorithms.value_and_pullback(f, x0) - grad_f_x0 = @inferred pb() + f_x0, cl = ProximalAlgorithms.value_and_gradient_closure(f, x0) + grad_f_x0 = @inferred cl() lam = R(0.1) * norm(A' * b, Inf) @test typeof(lam) == R From 9ac83c01cc04c3da6d198fd0b013a5375ab448f3 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 20 Jan 2024 18:28:05 +0100 Subject: [PATCH 22/25] update benchmark --- benchmark/benchmarks.jl | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index a63736d..74a3dff 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -8,6 +8,22 @@ using FileIO const SUITE = BenchmarkGroup() +function ProximalAlgorithms.value_and_gradient_closure(f::ProximalOperators.LeastSquaresDirect, x) + res = f.A*x - f.b + norm(res)^2, () -> f.A'*res +end + +struct SquaredDistance{Tb} + b::Tb +end + +(f::SquaredDistance)(x) = norm(x - f.b)^2 + +function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x) + diff = x - f.b + norm(diff)^2, () -> diff +end + for (benchmark_name, file_name) in [ ("Lasso tiny", joinpath(@__DIR__, "data", "lasso_tiny.jld2")), ("Lasso small", joinpath(@__DIR__, "data", "lasso_small.jld2")), @@ -42,21 +58,21 @@ for (benchmark_name, file_name) in [ SUITE[k]["ZeroFPR"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin solver = ProximalAlgorithms.ZeroFPR(tol=1e-6) x0 = zeros($T, size($A, 2)) - f = Translate(SqrNormL2(), -$b) + f = SquaredDistance($b) g = NormL1($lam) end SUITE[k]["PANOC"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin solver = ProximalAlgorithms.PANOC(tol=1e-6) x0 = zeros($T, size($A, 2)) - f = Translate(SqrNormL2(), -$b) + f = SquaredDistance($b) g = NormL1($lam) end SUITE[k]["PANOCplus"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin solver = ProximalAlgorithms.PANOCplus(tol=1e-6) x0 = zeros($T, size($A, 2)) - f = Translate(SqrNormL2(), -$b) + f = SquaredDistance($b) g = NormL1($lam) end From 2d8d611e3006fdc05fff6acccb22dd156540801a Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sun, 21 Jan 2024 10:49:21 +0100 Subject: [PATCH 23/25] just benchmark --- justfile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/justfile b/justfile index 81ccb7f..5dc42fa 100644 --- a/justfile +++ b/justfile @@ -12,3 +12,8 @@ format: docs: julia --project=./docs docs/make.jl + +benchmark: + julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + julia --project=benchmark benchmark/runbenchmarks.jl + From fd98409e508c54817bc31eaacdb16f8e82ce79ad Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sun, 21 Jan 2024 11:02:49 
+0100 Subject: [PATCH 24/25] update readme and docs --- README.md | 21 +++++++++++++-------- docs/src/index.md | 29 +++++++++++++---------------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 15129be..bfff729 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,19 @@ A Julia package for non-smooth optimization algorithms. This package provides algorithms for the minimization of objective functions that include non-smooth terms, such as constraints or non-differentiable penalties. Implemented algorithms include: -* (Fast) Proximal gradient methods -* Douglas-Rachford splitting -* Three-term splitting -* Primal-dual splitting algorithms -* Newton-type methods - -This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15), -which contains a wide range of functions that can be used to express cost terms. +- (Fast) Proximal gradient methods +- Douglas-Rachford splitting +- Three-term splitting +- Primal-dual splitting algorithms +- Newton-type methods + +Algorithms rely on: +- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation +(but you can easily bring your own gradients) +- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, +to handle non-differentiable terms +(see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) +for an extensive collection of functions). ## Documentation diff --git a/docs/src/index.md b/docs/src/index.md index 87b95ef..b1a119f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,16 +5,21 @@ A Julia package for non-smooth optimization algorithms. [Link to GitHub reposito This package provides algorithms for the minimization of objective functions that include non-smooth terms, such as constraints or non-differentiable penalties. Implemented algorithms include: -* (Fast) Proximal gradient methods -* Douglas-Rachford splitting -* Three-term splitting -* Primal-dual splitting algorithms -* Newton-type methods +- (Fast) Proximal gradient methods +- Douglas-Rachford splitting +- Three-term splitting +- Primal-dual splitting algorithms +- Newton-type methods Check out [this section](@ref problems_algorithms) for an overview of the available algorithms. -This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15), -which contains a wide range of functions that can be used to express cost terms. +Algorithms rely on: +- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation +(but you can easily bring your own gradients) +- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, +to handle non-differentiable terms +(see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) +for an extensive collection of functions). !!! note @@ -23,20 +28,11 @@ which contains a wide range of functions that can be used to express cost terms. 
## Installation -Install the latest stable release with - ```julia julia> ] pkg> add ProximalAlgorithms ``` -To install the development version instead (`master` branch), do - -```julia -julia> ] -pkg> add ProximalAlgorithms#master -``` - ## Citing If you use any of the algorithms from ProximalAlgorithms in your research, you are kindly asked to cite the relevant bibliography. @@ -45,3 +41,4 @@ Please check [this section of the manual](@ref problems_algorithms) for algorith ## Contributing Contributions are welcome in the form of [issue notifications](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/issues) or [pull requests](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/pulls). When contributing new algorithms, we highly recommend looking at already implemented ones to get inspiration on how to structure the code. + From 68a80a0e225574257560e294c6fb3c1821e49ae5 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sun, 21 Jan 2024 11:55:44 +0100 Subject: [PATCH 25/25] update readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index bfff729..525ed7b 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ Implemented algorithms include: - Primal-dual splitting algorithms - Newton-type methods +Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms. + Algorithms rely on: - [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients)