Skip to content

Commit 832d6cd

Browse files
committed
Improve verbose mode
- `summary` function added to the `IterativeAlgorithm` struct. This function returns a tuple of pairs where the key is the column title. E.g.: ("" => it, , "f(xg)" => state.f_xg, ...) - The `display` function is modified to call summary and display the result. When `it = 0` is passed, then only a table header is printed. - When `freq` in `IterativeAlgorithm` is set to 0, then only a single line is printed after the iteration stops. The format of this line is like: "total iterations = 43, f(xg) = 3.524e-3, ..." - `default_display` function now accepts `printfunc` optional argument. The default value is `println`, and this argument allows replacing it, e.g., with a proper logger.
1 parent 01f4995 commit 832d6cd

File tree

12 files changed

+209
-115
lines changed

12 files changed

+209
-115
lines changed

src/ProximalAlgorithms.jl

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ADTypes: ADTypes
44
using DifferentiationInterface: DifferentiationInterface
55
using ProximalCore
66
using ProximalCore: prox, prox!
7+
using Printf
78

89
const RealOrComplex{R} = Union{R,Complex{R}}
910
const Maybe{T} = Union{T,Nothing}
@@ -55,18 +56,19 @@ include("accel/noaccel.jl")
5556

5657
# algorithm interface
5758

58-
struct IterativeAlgorithm{IteratorType,H,S,D,K}
59+
struct IterativeAlgorithm{IteratorType,H,S,I,D,K}
5960
maxit::Int
6061
stop::H
6162
solution::S
6263
verbose::Bool
6364
freq::Int
65+
summary::I
6466
display::D
6567
kwargs::K
6668
end
6769

6870
"""
69-
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...)
71+
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...)
7072
7173
Wrapper for an iterator type `T`, adding termination and verbosity options on top of it.
7274
@@ -75,7 +77,7 @@ The resulting "algorithm" object `alg` can be called on a set of keyword argumen
7577
to `kwargs` and passed on to `T` to construct an iterator which will be looped over.
7678
Specifically, if an algorithm is constructed as
7779
78-
alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...)
80+
alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...)
7981
8082
then calling it with
8183
@@ -88,7 +90,7 @@ will internally loop over an iterator constructed as
8890
# Note
8991
This constructor is not meant to be used directly: instead, algorithm-specific constructors
9092
should be defined on top of it and exposed to the user, that set appropriate default functions
91-
for `stop`, `solution`, `display`.
93+
for `stop`, `solution`, `summary`, `display`.
9294
9395
# Arguments
9496
* `T::Type`: iterator type to use
@@ -97,28 +99,78 @@ for `stop`, `solution`, `display`.
9799
* `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
98100
* `verbose::Bool`: whether the algorithm state should be displayed
99101
* `freq::Int`: every how many iterations to display the algorithm state
100-
* `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
102+
* `summary::Function`: function returning a summary of the iteration state, `summary(k::Int, iter::T, state)` should return a vector of pairs `(name, value)`
103+
* `display::Function`: display function, `display(k::Int, alg, iter::T, state)` should display a summary of the iteration state
101104
* `kwargs...`: keyword arguments to pass on to `T` when constructing the iterator
102105
"""
103-
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...) =
104-
IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(display),typeof(kwargs)}(
106+
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...) =
107+
IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(summary),typeof(display),typeof(kwargs)}(
105108
maxit,
106109
stop,
107110
solution,
108111
verbose,
109112
freq,
113+
summary,
110114
display,
111115
kwargs,
112116
)
113117

118+
function default_display(k, alg, iter, state, printfunc=println)
119+
if alg.freq > 0
120+
summary = alg.summary(k, iter, state)
121+
column_widths = map(pair -> max(length(pair.first), pair.second isa Integer ? 5 : 9), summary)
122+
if k == 0
123+
keys = map(first, summary)
124+
first_line = [_get_centered_text(key, width) for (width, key) in zip(column_widths, keys)]
125+
printfunc(join(first_line, " | "))
126+
second_line = [repeat('-', width) for width in column_widths]
127+
printfunc(join(second_line, "-|-"), "-")
128+
else
129+
values = map(last, summary)
130+
parts = [_format_value(value, width) for (width, value) in zip(column_widths, values)]
131+
printfunc(join(parts, " | "))
132+
end
133+
else
134+
summary = alg.summary(k, iter, state)
135+
if summary[1].first == ""
136+
summary = ("total iterations" => k, summary[2:end]...)
137+
end
138+
items = map(pair -> @sprintf("%s=%s", pair.first, _format_value(pair.second, 0)), summary)
139+
printfunc(join(items, ", "))
140+
end
141+
end
142+
143+
function _get_centered_text(text, width)
144+
l = length(text)
145+
if l >= width
146+
return text
147+
end
148+
left_padding = div(width - l, 2)
149+
right_padding = width - l - left_padding
150+
return repeat(" ", left_padding) * text * repeat(" ", right_padding)
151+
end
152+
153+
function _format_value(value, width)
154+
if value isa Integer
155+
return @sprintf("%*d", width, value)
156+
elseif value isa Float64 || value isa Float32
157+
return @sprintf("%*.3e", width, value)
158+
else
159+
return @sprintf("%*s", width, string(value))
160+
end
161+
end
162+
114163
function (alg::IterativeAlgorithm{IteratorType})(; kwargs...) where {IteratorType}
115164
iter = IteratorType(; alg.kwargs..., kwargs...)
116165
for (k, state) in enumerate(iter)
166+
if k == 1 && alg.verbose && alg.freq > 0
167+
alg.display(0, alg, iter, state)
168+
end
117169
if k >= alg.maxit || alg.stop(iter, state)
118-
alg.verbose && alg.display(k, iter, state)
170+
alg.verbose && alg.display(k, alg, iter, state)
119171
return (alg.solution(iter, state), k)
120172
end
121-
alg.verbose && mod(k, alg.freq) == 0 && alg.display(k, iter, state)
173+
alg.verbose && alg.freq > 0 && mod(k, alg.freq) == 0 && alg.display(k, alg, iter, state)
122174
end
123175
end
124176

src/algorithms/davis_yin.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ end
4444

4545
Base.IteratorSize(::Type{<:DavisYinIteration}) = Base.IsInfinite()
4646

47-
struct DavisYinState{T}
47+
struct DavisYinState{T,R}
4848
z::T
4949
xg::T
50+
f_xg::R
5051
grad_f_xg::T
5152
z_half::T
5253
xh::T
54+
g_xh::R
5355
res::T
5456
end
5557

@@ -58,10 +60,10 @@ function Base.iterate(iter::DavisYinIteration)
5860
xg, = prox(iter.g, z, iter.gamma)
5961
f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
6062
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
61-
xh, = prox(iter.h, z_half, iter.gamma)
63+
xh, g_xh = prox(iter.h, z_half, iter.gamma)
6264
res = xh - xg
6365
z .+= iter.lambda .* res
64-
state = DavisYinState(z, xg, grad_f_xg, z_half, xh, res)
66+
state = DavisYinState(z, xg, f_xg, grad_f_xg, z_half, xh, g_xh, res)
6567
return state, state
6668
end
6769

@@ -79,8 +81,8 @@ end
7981
default_stopping_criterion(tol, ::DavisYinIteration, state::DavisYinState) =
8082
norm(state.res, Inf) <= tol
8183
default_solution(::DavisYinIteration, state::DavisYinState) = state.xh
82-
default_display(it, ::DavisYinIteration, state::DavisYinState) =
83-
@printf("%5d | %.3e\n", it, norm(state.res, Inf))
84+
default_iteration_summary(it, ::DavisYinIteration, state::DavisYinState) =
85+
("" => it, "f(xg)" => state.f_xg, "g(xh)" => state.g_xh, "‖xg - xh‖" => norm(state.res, Inf))
8486

8587
"""
8688
DavisYin(; <keyword-arguments>)
@@ -101,11 +103,12 @@ See also: [`DavisYinIteration`](@ref), [`IterativeAlgorithm`](@ref).
101103
# Arguments
102104
- `maxit::Int=10_000`: maximum number of iteration
103105
- `tol::1e-8`: tolerance for the default stopping criterion
104-
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
105-
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
106+
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
107+
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
106108
- `verbose::Bool=false`: whether the algorithm state should be displayed
107-
- `freq::Int=100`: every how many iterations to display the algorithm state
108-
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
109+
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
110+
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
111+
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
109112
- `kwargs...`: additional keyword arguments to pass on to the `DavisYinIteration` constructor upon call
110113
111114
# References
@@ -118,6 +121,7 @@ DavisYin(;
118121
solution = default_solution,
119122
verbose = false,
120123
freq = 100,
124+
summary=default_iteration_summary,
121125
display = default_display,
122126
kwargs...,
123127
) = IterativeAlgorithm(
@@ -127,6 +131,7 @@ DavisYin(;
127131
solution,
128132
verbose,
129133
freq,
134+
summary,
130135
display,
131136
kwargs...,
132137
)

src/algorithms/douglas_rachford.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,26 @@ end
4242

4343
Base.IteratorSize(::Type{<:DouglasRachfordIteration}) = Base.IsInfinite()
4444

45-
Base.@kwdef struct DouglasRachfordState{Tx}
45+
Base.@kwdef struct DouglasRachfordState{Tx,R}
4646
x::Tx
4747
y::Tx = similar(x)
48+
f_y::R = real(eltype(x))(0)
4849
r::Tx = similar(x)
4950
z::Tx = similar(x)
51+
g_z::R = real(eltype(x))(0)
5052
res::Tx = similar(x)
5153
end
5254

5355
function Base.iterate(
5456
iter::DouglasRachfordIteration,
5557
state::DouglasRachfordState = DouglasRachfordState(x = copy(iter.x0)),
5658
)
57-
prox!(state.y, iter.f, state.x, iter.gamma)
59+
f_y = prox!(state.y, iter.f, state.x, iter.gamma)
5860
state.r .= 2 .* state.y .- state.x
59-
prox!(state.z, iter.g, state.r, iter.gamma)
61+
g_z = prox!(state.z, iter.g, state.r, iter.gamma)
6062
state.res .= state.y .- state.z
6163
state.x .-= state.res
64+
state = DouglasRachfordState(state.x, state.y, f_y, state.r, state.z, g_z, state.res)
6265
return state, state
6366
end
6467

@@ -68,8 +71,8 @@ default_stopping_criterion(
6871
state::DouglasRachfordState,
6972
) = norm(state.res, Inf) / iter.gamma <= tol
7073
default_solution(::DouglasRachfordIteration, state::DouglasRachfordState) = state.y
71-
default_display(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) =
72-
@printf("%5d | %.3e\n", it, norm(state.res, Inf) / iter.gamma)
74+
default_iteration_summary(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) =
75+
("" => it, "f(y)" => state.f_y, "g(z)" => state.g_z, "‖y - z‖" => norm(state.res, Inf) / iter.gamma)
7376

7477
"""
7578
DouglasRachford(; <keyword-arguments>)
@@ -88,11 +91,12 @@ See also: [`DouglasRachfordIteration`](@ref), [`IterativeAlgorithm`](@ref).
8891
# Arguments
8992
- `maxit::Int=1_000`: maximum number of iteration
9093
- `tol::1e-8`: tolerance for the default stopping criterion
91-
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
92-
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
94+
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
95+
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
9396
- `verbose::Bool=false`: whether the algorithm state should be displayed
94-
- `freq::Int=100`: every how many iterations to display the algorithm state
95-
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
97+
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
98+
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
99+
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
96100
- `kwargs...`: additional keyword arguments to pass on to the `DouglasRachfordIteration` constructor upon call
97101
98102
# References
@@ -105,6 +109,7 @@ DouglasRachford(;
105109
solution = default_solution,
106110
verbose = false,
107111
freq = 100,
112+
summary = default_iteration_summary,
108113
display = default_display,
109114
kwargs...,
110115
) = IterativeAlgorithm(
@@ -114,6 +119,7 @@ DouglasRachford(;
114119
solution,
115120
verbose,
116121
freq,
122+
summary,
117123
display,
118124
kwargs...,
119125
)

src/algorithms/drls.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,13 @@ end
197197
default_stopping_criterion(tol, ::DRLSIteration, state::DRLSState) =
198198
norm(state.res, Inf) / state.gamma <= tol
199199
default_solution(::DRLSIteration, state::DRLSState) = state.v
200-
default_display(it, ::DRLSIteration, state::DRLSState) = @printf(
201-
"%5d | %.3e | %.3e | %.3e\n",
202-
it,
203-
state.gamma,
204-
norm(state.res, Inf) / state.gamma,
205-
state.tau,
206-
)
200+
default_iteration_summary(it, ::DRLSIteration, state::DRLSState) =
201+
("" => it,
202+
"f(u)" => state.f_u,
203+
"g(v)" => state.g_v,
204+
"γ" => state.gamma,
205+
"‖u - v‖/γ" => norm(state.res, Inf) / state.gamma,
206+
"τ" => state.tau)
207207

208208
"""
209209
DRLS(; <keyword-arguments>)
@@ -224,11 +224,12 @@ See also: [`DRLSIteration`](@ref), [`IterativeAlgorithm`](@ref).
224224
# Arguments
225225
- `maxit::Int=1_000`: maximum number of iteration
226226
- `tol::1e-8`: tolerance for the default stopping criterion
227-
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
228-
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
227+
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
228+
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
229229
- `verbose::Bool=false`: whether the algorithm state should be displayed
230-
- `freq::Int=10`: every how many iterations to display the algorithm state
231-
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
230+
- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
231+
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
232+
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
232233
- `kwargs...`: additional keyword arguments to pass on to the `DRLSIteration` constructor upon call
233234
234235
# References
@@ -241,6 +242,7 @@ DRLS(;
241242
solution = default_solution,
242243
verbose = false,
243244
freq = 10,
245+
summary = default_iteration_summary,
244246
display = default_display,
245247
kwargs...,
246248
) = IterativeAlgorithm(
@@ -250,6 +252,7 @@ DRLS(;
250252
solution,
251253
verbose,
252254
freq,
255+
summary,
253256
display,
254257
kwargs...,
255258
)

src/algorithms/fast_forward_backward.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,13 @@ default_stopping_criterion(
150150
state::FastForwardBackwardState,
151151
) = norm(state.res, Inf) / state.gamma <= tol
152152
default_solution(::FastForwardBackwardIteration, state::FastForwardBackwardState) = state.z
153-
default_display(it, ::FastForwardBackwardIteration, state::FastForwardBackwardState) =
154-
@printf("%5d | %.3e | %.3e\n", it, state.gamma, norm(state.res, Inf) / state.gamma)
153+
default_iteration_summary(it, iter::FastForwardBackwardIteration, state::FastForwardBackwardState) = begin
154+
if iter.adaptive
155+
("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "γ" => state.gamma, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma)
156+
else
157+
("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma)
158+
end
159+
end
155160

156161
"""
157162
FastForwardBackward(; <keyword-arguments>)
@@ -172,11 +177,12 @@ See also: [`FastForwardBackwardIteration`](@ref), [`IterativeAlgorithm`](@ref).
172177
# Arguments
173178
- `maxit::Int=10_000`: maximum number of iteration
174179
- `tol::1e-8`: tolerance for the default stopping criterion
175-
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
176-
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
180+
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
181+
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
177182
- `verbose::Bool=false`: whether the algorithm state should be displayed
178-
- `freq::Int=100`: every how many iterations to display the algorithm state
179-
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
183+
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
184+
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
185+
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
180186
- `kwargs...`: additional keyword arguments to pass on to the `FastForwardBackwardIteration` constructor upon call
181187
182188
# References
@@ -190,6 +196,7 @@ FastForwardBackward(;
190196
solution = default_solution,
191197
verbose = false,
192198
freq = 100,
199+
summary = default_iteration_summary,
193200
display = default_display,
194201
kwargs...,
195202
) = IterativeAlgorithm(
@@ -199,6 +206,7 @@ FastForwardBackward(;
199206
solution,
200207
verbose,
201208
freq,
209+
summary,
202210
display,
203211
kwargs...,
204212
)

0 commit comments

Comments
 (0)