Commit 4db6042

aggregation algorithms fix
1 parent aac7bd5 commit 4db6042

3 files changed, with 53 additions and 46 deletions


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "AsynchronousIterativeAlgorithms"
 uuid = "329f2bf2-1773-4f40-9abf-1830ae341a86"
 authors = ["selim-chraibi"]
-version = "0.1.5"
+version = "0.1.6"
 
 [deps]
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

docs/src/manual.md

Lines changed: 12 additions & 11 deletions
@@ -5,7 +5,7 @@ We saw how to run an asynchronous version of the SGD algorithm on a LRMSE proble
 - [Working with a distributed problem](@ref)
 - [Synchronous run](@ref)
 - [Active processes](@ref)
-- [Recording iterates](@ref)
+- [Recording iterates](@ref recording_iterated)
 - [Customization of `start`'s execution](@ref custom_execution)
 - [Handling worker failures](@ref)
 - [Algorithm wrappers](@ref algorithm_wrappers)
@@ -262,7 +262,7 @@ $$q_j \longleftarrow \textrm{query}(\underset{i \in \textrm{connected}}{\textrm{
 
 where $q_j$ is computed by the worker upon reception of $\textrm{answer}(q_i)$ from worker $j$ and where $\textrm{connected}$ is the list of workers that have answered.
 
-The [`AggregationAlgorithm`](@ref) in this library requires you to specify three methods: query, answer, and aggregate. Here's an example showing the required signatures of these three methods:
+The [`AggregationAlgorithm`](@ref) in this library requires you to define four methods: `initialize`, `query`, `answer`, and `aggregate`. Here's an example showing the required signatures of these four methods:
 
 ```julia
 @everywhere begin
@@ -273,10 +273,10 @@ The [`AggregationAlgorithm`](@ref) in this library requires you to specify three
         stepsize::Float64
     end
 
-    (tba::ToBeAggregatedGD)(problem::Any) = tba.q1
-    (tba::ToBeAggregatedGD)(a::Vector{Vector{Float64}}, connected::Vector{Int64}) = mean(a)
-    (tba::ToBeAggregatedGD)(a::Vector{Float64}, problem::Any) = a
-    (tba::ToBeAggregatedGD)(q::Vector{Float64}, problem::Any) = q - tba.stepsize * problem.∇f(q)
+    AIA.initialize(tba::ToBeAggregatedGD, problem::Any) = tba.q1
+    AIA.aggregate(tba::ToBeAggregatedGD, a::Vector{Vector{Float64}}, connected::Vector{Int64}) = mean(a)
+    AIA.query(tba::ToBeAggregatedGD, a::Vector{Float64}, problem::Any) = a
+    AIA.answer(tba::ToBeAggregatedGD, q::Vector{Float64}, problem::Any) = q - tba.stepsize * problem.∇f(q)
 end
 
 algorithm = AggregationAlgorithm(ToBeAggregatedGD(rand(10), 0.01); pids=workers())
@@ -286,7 +286,7 @@ history = start(algorithm, distributed_problem, (epoch=100,));
 
 **Memory limitation:** At any point in time, the central worker must have access to the latest answers $a_i$ from *all* the connected workers. This means storing one $a_i$ per worker, which adds up when many workers are used. There is a workaround when the aggregation operation is an *average*. In this case, only the equivalent of one answer needs to be saved on the central node, regardless of the number of workers.
 
-[`AveragingAlgorithm`](@ref) implements this memory optimization. Here you only need to define `query`, the `answer`
+[`AveragingAlgorithm`](@ref) implements this memory optimization. Here you only need to define `initialize`, `query`, and `answer`:
 
 ```julia
 @everywhere begin
@@ -295,9 +295,9 @@ history = start(algorithm, distributed_problem, (epoch=100,));
         stepsize::Float64
     end
 
-    (tba::ToBeAveragedGD)(problem::Any) = tba.q1
-    (tba::ToBeAveragedGD)(a::Vector{Float64}, problem::Any) = a
-    (tba::ToBeAveragedGD)(q::Vector{Float64}, problem::Any) = q - tba.stepsize * problem.∇f(q)
+    AIA.initialize(tba::ToBeAveragedGD, problem::Any) = tba.q1
+    AIA.query(tba::ToBeAveragedGD, a::Vector{Float64}, problem::Any) = a
+    AIA.answer(tba::ToBeAveragedGD, q::Vector{Float64}, problem::Any) = q - tba.stepsize * problem.∇f(q)
 end
 
 algorithm = AveragingAlgorithm(ToBeAveragedGD(rand(10), 0.01); pids=workers(), weights=ones(nworkers()))
@@ -311,5 +311,6 @@ Note that you can implement the [custom callbacks](@ref custom_execution) on bot
 report(::ToBeAggregatedGD) = # do something
 ```
 
+---
 
-Hope you find this library helpful and look forward to seeing how you put it to use!
+Wow you read all this! Hope you find this library helpful and look forward to seeing how you put it to use!
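
Note on the manual change above: the callable-struct convention is replaced by four named functions that user types extend. As a rough illustration of how that dispatch pattern behaves, here is a self-contained sketch that does not load the real package. `MiniAIA` is a hypothetical stand-in module, `ToyGD` and `Incomplete` are made-up types, and the fallback bodies only mirror in spirit the `ArgumentError` fallbacks added in `src/algorithm_wrappers.jl`.

```julia
# Hypothetical stand-in module; NOT the real AsynchronousIterativeAlgorithms API.
module MiniAIA
    abstract type AbstractAlgorithm{Q,A} end

    # Fallbacks: reached only when a user type does not extend the function.
    initialize(alg::AbstractAlgorithm, problem) =
        throw(ArgumentError("initialize not implemented for $(typeof(alg))"))
    aggregate(alg::AbstractAlgorithm, as::Vector, workers::Vector{Int}) =
        throw(ArgumentError("aggregate not implemented for $(typeof(alg))"))
    query(alg::AbstractAlgorithm, agg, problem) =
        throw(ArgumentError("query not implemented for $(typeof(alg))"))
    answer(alg::AbstractAlgorithm, q, problem) =
        throw(ArgumentError("answer not implemented for $(typeof(alg))"))
end

# Made-up user algorithm: plain gradient descent on each worker.
struct ToyGD <: MiniAIA.AbstractAlgorithm{Vector{Float64},Vector{Float64}}
    q1::Vector{Float64}
    stepsize::Float64
end

# The user extends the module's functions by qualified name.
MiniAIA.initialize(alg::ToyGD, problem) = alg.q1
MiniAIA.aggregate(alg::ToyGD, as::Vector{Vector{Float64}}, workers::Vector{Int}) = sum(as) / length(as)
MiniAIA.query(alg::ToyGD, agg::Vector{Float64}, problem) = agg
MiniAIA.answer(alg::ToyGD, q::Vector{Float64}, problem) = q - alg.stepsize * problem.∇f(q)

# One hand-rolled round: initialize, two identical "workers" answer, aggregate, query.
problem = (∇f = q -> 2 .* q,)            # gradient of f(q) = sum(q .^ 2)
alg = ToyGD([1.0, 2.0], 0.1)
q0 = MiniAIA.initialize(alg, problem)
as = [MiniAIA.answer(alg, q0, problem) for _ in 1:2]
q_next = MiniAIA.query(alg, MiniAIA.aggregate(alg, as, [2, 3]), problem)

# A type that extends nothing hits the fallback and gets a readable error.
struct Incomplete <: MiniAIA.AbstractAlgorithm{Float64,Float64} end
missing_method_error = try
    MiniAIA.initialize(Incomplete(), problem)
catch err
    err
end
```

Compared with overloading the call operator several times on one struct, named functions make it explicit which step a method implements, and a forgotten step fails with a message naming the function instead of a generic `MethodError`.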

src/algorithm_wrappers.jl

Lines changed: 40 additions & 34 deletions
@@ -1,18 +1,23 @@
 export AggregationAlgorithm, AveragingAlgorithm
 
+initialize(algorithm::AbstractAlgorithm, problem::Any) = throw(ArgumentError("Method initialize(::$(typeof(algorithm)), problem::Any) not implemented."))
+aggregate(algorithm::AbstractAlgorithm{Q,A}, as::Vector{A}, workers::Vector{Int64}) where {Q,A} = throw(ArgumentError("Method aggregate(::$(typeof(algorithm)), problem::Any) not implemented."))
+query(algorithm::AbstractAlgorithm, agg::AggregatedA, problem::Any) where {AggregatedA} = throw(ArgumentError("Method query(::$(typeof(algorithm)), problem::Any) not implemented."))
+answer(algorithm::AbstractAlgorithm{Q,A}, q::Q, problem::Any) where {Q,A} = throw(ArgumentError("Method answer(::$(typeof(algorithm)), problem::Any) not implemented."))
+
 """
-    AggregationAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}}(arg; kwarg)::AbstractAlgorithm{Q,A} where {Q,A}
+    AggregationAlgorithm(arg; kwarg)::AbstractAlgorithm
 
 Distributed algorithm that writes: `q_j <- query(aggregate([answer(q_i) for i in connected]))`
 Where a "connected" worker is a worker that has answered at least once.
 (Not memory optimized: `length(pids)` answers are stored on the central worker at all times)
 
 # Argument
-- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following
-    - `algorithm(problem::Any)::Q`: the initialization step that create the first query iterate
-    - `algorithm(as::Vector{A}, workers::Vector{Int64})::AggregatedA` where A: the answer aggregation step performed by the central node when receiving the answers `as::Vector{A}` from the `workers`
-    - `algorithm(agg::AggregatedA, problem::Any)::Q`: the query step producing a query from the aggregated answer `agg::AggregatedA`, performed by the central node
-    - `algorithm(q::Q, problem::Any)::A`: the answer step performed by the workers when they receive a query `q::Q` from the central node
+- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following (where `const AIA = AsynchronousIterativeAlgorithms`)
+    - `AIA.initialize(algorithm, problem::Any)::Q`: step that creates the first query iterate
+    - `AIA.aggregate(algorithm, as::Vector{A}, workers::Vector{Int64})::AggregatedA` where A: step performed by the central node when receiving the answers `as::Vector{A}` from the `workers`
+    - `AIA.query(algorithm, agg::AggregatedA, problem::Any)::Q`: step producing a query from the aggregated answer `agg::AggregatedA`, performed by the central node
+    - `AIA.answer(algorithm, q::Q, problem::Any)::A`: step performed by the workers when they receive a query `q::Q` from the central node
 
 # Keyword
 - `pids=workers()`: `pids` of the active workers
@@ -30,41 +35,41 @@ struct AggregationAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAlgorith
 end
 
 """
-    (::AggregationAlgorithm{Q,A,Alg})(problem::Any)::Q where {Q,A,Alg}
+    (::AggregationAlgorithm{Q,A,Alg})(problem::Any)::Q where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 The initialization step that creates the first query iterate
 """
 function (agg::AggregationAlgorithm)(problem::Any)
-    agg.algorithm(problem)
+    initialize(agg.algorithm, problem)
 end
 
 """
-    (::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any)::Q where {Q,A,Alg}
+    (::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any)::Q where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Asynchronous step performed by the central node when receiving an answer `a::A` from a worker
 """
-function (agg::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any) where {Q,A,Alg}
+function (agg::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
     agg.connected[worker] = true
     agg.answers[worker] = a
-    agg.algorithm(agg.algorithm(agg.answers[agg.connected], (1:maximum(agg.pids))[agg.connected]), problem)
+    query(agg.algorithm, aggregate(agg.algorithm, agg.answers[agg.connected], (1:maximum(agg.pids))[agg.connected]), problem)
 end
 
 """
-    (::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any)::Q where {Q,A,Alg}
+    (::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any)::Q where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Synchronous step performed by the central node when receiving answers `as::Vector{A}` respectively from `workers::Vector{Int64}`
 """
-function (agg::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
-    agg.algorithm(agg.algorithm(as, workers), problem)
+function (agg::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
+    query(agg.algorithm, aggregate(agg.algorithm, as, workers), problem)
 end
 
 """
-    (::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any)->A where {Q,A,Alg}
+    (::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any)->A where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Steps performed by the workers when they receive a query `q::Q` from the central node
 """
-function (agg::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
-    agg.algorithm(q, problem)
+function (agg::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
+    answer(agg.algorithm, q, problem)
 end
 
 stopnow(agg::AggregationAlgorithm, stopat::NamedTuple) = stopnow(agg.algorithm, stopat)
@@ -76,20 +81,21 @@ savevalues(agg::AggregationAlgorithm) = savevalues(agg.algorithm)
 
 
 """
-    AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}}(arg; kwarg)::AbstractAlgorithm{Q,A} where {Q,A}
+    AveragingAlgorithm(arg; kwarg)::AbstractAlgorithm
 
 Distributed algorithm that writes: `q_j <- query(weighted_average([answer(q_i) for i in connected]))`
 Where a "connected" worker is a worker that has answered at least once.
 (Memory optimized: only the equivalent of one answer is stored on the central worker at all times)
 
 # Argument
-- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following
-    - `algorithm(problem::Any)::Q`: the initialization step that create the first query iterate
-    - `algorithm(a::A, problem::Any)::Q`: the query step producing a query from the averaged answer, performed by the central node
-    - `algorithm(q::Q, problem::Any)::A`: the answer step performed by the workers when they receive a query `q::Q` from the central node
+- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following (where `const AIA = AsynchronousIterativeAlgorithms`)
+    - `AIA.initialize(algorithm, problem::Any)::Q`: step that creates the first query iterate
+    - `AIA.query(algorithm, a::A, problem::Any)::Q`: step producing a query from the averaged answer, performed by the central node
+    - `AIA.answer(algorithm, q::Q, problem::Any)::A`: step performed by the workers when they receive a query `q::Q` from the central node
 
 # Keyword
 - `pids=workers()`: `pids` of the active workers
+- `weights=ones(length(pids))`: weights of each pid in the weighted average
 """
 mutable struct AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAlgorithm{Q,A}
     pids::Vector{Int64}
@@ -100,7 +106,7 @@ mutable struct AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAl
     last_answers::Vector{A}
     last_average::Union{A,Nothing}
     weights::Vector{Float64}
-    function AveragingAlgorithm(algorithm::Alg; pids=procs(), weights=ones(nprocs())) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
+    function AveragingAlgorithm(algorithm::Alg; pids=procs(), weights=ones(length(pids))) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
         @assert length(pids) == length(weights) "There should be as many weights as there are pids"
         maxpid = maximum(pids)
         connected = BitVector(zeros(maxpid))
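
Note on the hunk above: the one-line change makes the default `weights` follow the `pids` actually passed instead of the global process count, so passing `pids` alone no longer trips the length assertion. A minimal sketch of the difference, assuming a made-up constant `N_PROCS` in place of `nprocs()` and two hypothetical helper functions so that it runs without `Distributed`:

```julia
const N_PROCS = 5   # hypothetical stand-in for nprocs()

# Old-style default: `weights` is sized from the global process count,
# whatever `pids` the caller passes.
old_default(; pids=collect(1:N_PROCS), weights=ones(N_PROCS)) =
    length(pids) == length(weights)

# New-style default: a keyword default may refer to an earlier keyword,
# so `weights` always matches `pids` unless the caller overrides it.
new_default(; pids=collect(1:N_PROCS), weights=ones(length(pids))) =
    length(pids) == length(weights)

old_ok = old_default(pids=[2, 3])   # 2 pids against N_PROCS weights
new_ok = new_default(pids=[2, 3])   # 2 pids against 2 weights
```

Both variants agree when no keyword is given; they differ only when the caller restricts `pids` and leaves `weights` unset.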
@@ -118,44 +124,44 @@ mutable struct AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAl
 end
 
 """
-    (::AveragingAlgorithm{Q,A,Alg})(problem::Any) where {Q,A,Alg}
+    (::AveragingAlgorithm{Q,A,Alg})(problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 The initialization step that creates the first query iterate
 """
 function (avg::AveragingAlgorithm)(problem::Any)
-    avg.algorithm(problem)
+    initialize(avg.algorithm, problem)
 end
 
 """
-    (::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
+    (::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Asynchronous step performed by the central node when receiving an answer `a::A` from a worker.
 """
-function (avg::AveragingAlgorithm{Q,A,Alg})(δa::A, worker::Int64, problem::Any) where {Q,A,Alg}
+function (avg::AveragingAlgorithm{Q,A,Alg})(δa::A, worker::Int64, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
     avg.connected[worker] = true
     normalization = sum(avg.connected .* avg.weights)
     avg.last_average = isnothing(avg.last_average) ? δa : (avg.weights[worker] * δa + avg.last_normalization * avg.last_average) / normalization
     avg.last_normalization = normalization
-    avg.algorithm(avg.last_average, problem)
+    query(avg.algorithm, avg.last_average, problem)
 end
 
 """
-    (::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
+    (::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Synchronous step performed by the central node when receiving answers `a::Vector{A}` respectively from `workers::Vector{Int64}`
 """
-function (avg::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
+function (avg::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
     avg.last_average = isnothing(avg.last_average) ? δas : sum(avg.weights[workers] * δas) / sum(avg.weights) + avg.last_average
-    avg.algorithm(avg.last_average, problem)
+    query(avg.algorithm, avg.last_average, problem)
 end
 
 """
-    (::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
+    (::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
 
 Steps performed by the workers when they receive a query `q::Q` from the central node
 """
-function (avg::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
-    a = avg.algorithm(q, problem)
+function (avg::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
+    a = answer(avg.algorithm, q, problem)
     δa = isnothing(avg.last_answer) ? a : a - avg.last_answer
     avg.last_answer = a
     return δa
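
Note on the `AveragingAlgorithm` steps above: workers send the difference `δa` between their new and previous answer, and the central node folds it into a running average instead of keeping every answer. The sketch below checks that idea numerically under a simplifying assumption: every worker has already answered once, so the normalization is constant. `weighted_average` and `update_average` are hypothetical helpers written for this note, not functions of the library.

```julia
# Direct weighted average: needs the latest answer of every worker in memory.
weighted_average(answers, weights) = sum(weights .* answers) / sum(weights)

# Incremental update: needs only the running average and the incoming delta.
update_average(average, δa, weight, total_weight) = average + weight * δa / total_weight

weights = [1.0, 2.0, 1.0]
answers = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]
average = weighted_average(answers, weights)

# Worker 2 produces a new answer and sends only the difference.
new_answer = [2.0, 6.0]
δa = new_answer - answers[2]
average = update_average(average, δa, weights[2], sum(weights))

# Reference value, recomputed from all stored answers for comparison only.
answers[2] = new_answer
reference = weighted_average(answers, weights)
```

The incremental form stores one vector on the central node regardless of the number of workers, at the cost of each worker remembering its own last answer.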
