Skip to content

Commit b7e1037

Browse files
committed
custom execution
1 parent bdc8dc9 commit b7e1037

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

src/algorithm_wrappers.jl

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
export AggregationAlgorithm, AveragingAlgorithm
2+
3+
"""
4+
AggregationAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}}(arg; kwarg)::AbstractAlgorithm{Q,A} where {Q,A}
5+
6+
Distributed algorithm that writes: `q_j <- query(aggregate([answer(q_i) for i in connected]))`
7+
Where a "connected" worker is a worker that has answered at least once.
8+
(Not memory optimized: `length(pids)` answers are stored on the central worker at all times)
9+
10+
# Argument
11+
- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following
12+
- `algorithm(problem::Any)::Q`: the initialization step that create the first query iterate
13+
- `algorithm(as::Vector{A}, workers::Vector{Int64})::AggregatedA` where A: the answer aggregarion step performed by the central node when receiving the answers `as::Vector{A}` from the `workers`
14+
- `algorithm(agg::AggregatedA, problem::Any)::Q`: the query step producing a query from the aggregated answer `agg::AggregatedA`, performed by the central node
15+
- `algorithm(q::Q, problem::Any)::A`: the answer step perfromed by the wokers when they receive a query `q::Q` from the central node
16+
17+
# Keyword
18+
- `pids=workers()`: `pids` of the active workers
19+
"""
20+
struct AggregationAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAlgorithm{Q,A}
21+
algorithm::Alg
22+
pids::Vector{Int64}
23+
answers::Vector{A}
24+
connected::BitVector
25+
function AggregationAlgorithm(algorithm::Alg; pids=workers()) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
26+
connected = BitVector(zeros(maximum(pids)))
27+
answers = Vector{A}(undef, maximum(pids))
28+
new{Q,A,Alg}(algorithm, pids, answers, connected)
29+
end
30+
end
31+
32+
"""
33+
(::AggregationAlgorithm{Q,A,Alg})(problem::Any)::Q where {Q,A,Alg}
34+
35+
The initialization step that create the first query iterate
36+
"""
37+
function (agg::AggregationAlgorithm)(problem::Any)
38+
agg.algorithm(problem)
39+
end
40+
41+
"""
42+
(::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any)::Q where {Q,A,Alg}
43+
44+
Asynchronous step performed by the central node when receiving an answer `a::A` from a worker
45+
"""
46+
function (agg::AggregationAlgorithm{Q,A,Alg})(a::A, worker::Int64, problem::Any) where {Q,A,Alg}
47+
agg.connected[worker] = true
48+
agg.answers[worker] = a
49+
agg.algorithm(agg.algorithm(agg.answers[agg.connected], (1:maximum(agg.pids))[agg.connected]), problem)
50+
end
51+
52+
"""
53+
(::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any)::Q where {Q,A,Alg}
54+
55+
Synchronous step performed by the central node when receiving answers `as::Vector{A}` respectively from `workers::Vector{Int64}`
56+
"""
57+
function (agg::AggregationAlgorithm{Q,A,Alg})(as::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
58+
agg.algorithm(agg.algorithm(as, workers), problem)
59+
end
60+
61+
"""
62+
(::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any)->A where {Q,A,Alg}
63+
64+
Steps performed by the workers when they receive a query `q::Q` from the central node
65+
"""
66+
function (agg::AggregationAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
67+
agg.algorithm(q, problem)
68+
end
69+
70+
stopnow(agg::AggregationAlgorithm, stopat::NamedTuple) = stopnow(agg.algorithm, stopat)
71+
showvalues(agg::AggregationAlgorithm) = showvalues(agg.algorithm)
72+
report(agg::AggregationAlgorithm) = report(agg.algorithm)
73+
progress(agg::AggregationAlgorithm, stopat::NamedTuple) = progress(agg.algorithm, stopat)
74+
savenow(agg::AggregationAlgorithm, saveat::NamedTuple) = savenow(agg.algorithm, saveat)
75+
savevalues(agg::AggregationAlgorithm) = savevalues(agg.algorithm)
76+
77+
78+
"""
79+
AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}}(arg; kwarg)::AbstractAlgorithm{Q,A} where {Q,A}
80+
81+
Distributed algorithm that writes: `q_j <- query(weighted_average([answer(q_i) for i in connected]))`
82+
Where a "connected" worker is a worker that has answered at least once.
83+
(Memory optimized: only the equivalent of one answer is stored on the central worker at all times)
84+
85+
# Argument
86+
- `algorithm<:AbstractAlgorithm{Q,A}` which should define the following
87+
- `algorithm(problem::Any)::Q`: the initialization step that create the first query iterate
88+
- `algorithm(a::A, problem::Any)::Q`: the query step producing a query from the averaged answer, performed by the central node
89+
- `algorithm(q::Q, problem::Any)::A`: the answer step perfromed by the wokers when they receive a query `q::Q` from the central node
90+
91+
# Keyword
92+
- `pids=workers()`: `pids` of the active workers
93+
"""
94+
mutable struct AveragingAlgorithm{Q,A,Alg<:AbstractAlgorithm{Q,A}} <: AbstractAlgorithm{Q,A}
95+
pids::Vector{Int64}
96+
algorithm::Alg
97+
connected::BitVector
98+
last_normalization::Float64
99+
last_answer::Union{A,Nothing}
100+
last_answers::Vector{A}
101+
last_average::Union{A,Nothing}
102+
weights::Vector{Float64}
103+
function AveragingAlgorithm(algorithm::Alg; pids=procs(), weights=ones(nprocs())) where Alg<:AbstractAlgorithm{Q,A} where {Q,A}
104+
@assert length(pids) == length(weights) "There should be as many weights as there are pids"
105+
maxpid = maximum(pids)
106+
connected = BitVector(zeros(maxpid))
107+
sparse_weights = zeros(maxpid)
108+
for (pid, weight) in zip(pids, weights)
109+
sparse_weights[pid] = weight
110+
end
111+
last_normalization = 1.0
112+
last_answer = nothing
113+
last_answers = Vector{A}(undef, maximum(pids))
114+
last_average = nothing
115+
116+
new{Q,A,Alg}(pids, algorithm, connected, last_normalization, last_answer, last_answers, last_average, sparse_weights)
117+
end
118+
end
119+
120+
"""
121+
(::AveragingAlgorithm{Q,A,Alg})(problem::Any) where {Q,A,Alg}
122+
123+
The initialization step that create the first query iterate
124+
"""
125+
function (avg::AveragingAlgorithm)(problem::Any)
126+
avg.algorithm(problem)
127+
end
128+
129+
"""
130+
(::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
131+
132+
Asynchronous step performed by the central node when receiving an answer `a::A` from a worker.
133+
"""
134+
function (avg::AveragingAlgorithm{Q,A,Alg})(δa::A, worker::Int64, problem::Any) where {Q,A,Alg}
135+
avg.connected[worker] = true
136+
normalization = sum(avg.connected .* avg.weights)
137+
avg.last_average = isnothing(avg.last_average) ? δa : (avg.weights[worker] * δa + avg.last_normalization * avg.last_average) / normalization
138+
avg.last_normalization = normalization
139+
avg.algorithm(avg.last_average, problem)
140+
end
141+
142+
"""
143+
(::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
144+
145+
Synchronous step performed by the central node when receiving answers `a::Vector{A}` respectively from `workers::Vector{Int64}`
146+
"""
147+
function (avg::AveragingAlgorithm{Q,A,Alg})(δas::Vector{A}, workers::Vector{Int64}, problem::Any) where {Q,A,Alg}
148+
avg.last_average = isnothing(avg.last_average) ? δas : sum(avg.weights[workers] * δas) / sum(avg.weights) + avg.last_average
149+
avg.algorithm(avg.last_average, problem)
150+
end
151+
152+
"""
153+
(::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
154+
155+
Steps performed by the workers when they receive a query `q::Q` from the central node
156+
"""
157+
function (avg::AveragingAlgorithm{Q,A,Alg})(q::Q, problem::Any) where {Q,A,Alg}
158+
a = avg.algorithm(q, problem)
159+
δa = isnothing(avg.last_answer) ? a : a - avg.last_answer
160+
avg.last_answer = a
161+
return δa
162+
end
163+
164+
stopnow(avg::AveragingAlgorithm, stopat::NamedTuple) = stopnow(avg.algorithm, stopat)
165+
showvalues(avg::AveragingAlgorithm) = showvalues(avg.algorithm)
166+
report(avg::AveragingAlgorithm) = report(avg.algorithm)
167+
progress(avg::AveragingAlgorithm, stopat::NamedTuple) = progress(avg.algorithm, stopat)
168+
savenow(avg::AveragingAlgorithm, saveat::NamedTuple) = savenow(avg.algorithm, saveat)
169+
savevalues(avg::AveragingAlgorithm) = savevalues(avg.algorithm)

0 commit comments

Comments
 (0)