@@ -61,92 +61,131 @@ abstract type Ordering end
6161struct Sorted <: Ordering end
6262struct Unsorted <: Ordering end
6363
64+ abstract type ReductionNode end
65+ struct TopTreeNode <: ReductionNode
66+ rank :: Int
67+ end
68+ struct SubTreeNode <: ReductionNode
69+ rank :: Int
70+ end
71+
6472function reducedvalue (freduce:: Function ,rank,
73+ pipe:: BranchChannel ,ifsorted:: Ordering )
74+
75+ reducedvalue (freduce,
76+ rank > 0 ? SubTreeNode (rank) : TopTreeNode (rank),
77+ pipe,ifsorted)
78+ end
79+
80+ function reducedvalue (freduce:: Function ,node:: SubTreeNode ,
6581 pipe:: BranchChannel{Tmap,Tred} ,:: Unsorted ) where {Tmap,Tred}
6682
83+ self = take! (pipe. selfchannels. out) :: Tmap
6784 N = nchildren (pipe)
68- if rank > 0
69- self = take! (pipe. selfchannels. out) :: Tmap
70- if N > 0
71- reducechildren = freduce (take! (pipe. childrenchannels. out):: Tred for i= 1 : N):: Tred
72- res = freduce ((reducechildren, self)) :: Tred
73- elseif N == 0
74- res = freduce ((self,)) :: Tred
75- end
76- else
77- if N > 0
78- res = freduce (take! (pipe. childrenchannels. out):: Tred for i= 1 : N):: Tred
79- elseif N == 0
80- # N == 0 && rank <= 0
81- # shouldn't reach this
82- error (" nodes with rank <=0 must have children" )
85+ vals = Vector {Tred} (undef,N+ 1 )
86+ @sync begin
87+ @async vals[1 ] = freduce ([self]) :: Tred
88+ @async for i= 1 : N
89+ vals[i+ 1 ] = take! (pipe. childrenchannels. out):: Tred
8390 end
8491 end
85- return res
92+
93+ freduce (vals)
8694end
95+ function reducedvalue (freduce:: Function ,node:: TopTreeNode ,
96+ pipe:: BranchChannel{<:Any,Tred} ,:: Unsorted ) where {Tred}
8797
88- function reducedvalue (freduce:: Function ,rank,
98+ N = nchildren (pipe)
99+ if N == 0
100+ # shouldn't reach this
101+ error (" Nodes on the top tree must have children" )
102+ end
103+ vals = Vector {Tred} (undef,N)
104+ for i= 1 : N
105+ vals[i] = take! (pipe. childrenchannels. out):: Tred
106+ end
107+
108+ freduce (vals)
109+ end
110+
111+ function reducedvalue (freduce:: Function ,node:: SubTreeNode ,
89112 pipe:: BranchChannel{Tmap,Tred} ,:: Sorted ) where {Tmap,Tred}
90113
114+ rank = node. rank
91115 N = nchildren (pipe)
92116 leftchild = N > 0
93- selfvalpresent = rank > 0
94- vals = Vector {Tred} (undef,N + selfvalpresent)
117+ vals = Vector {Tred} (undef,N + 1 )
95118 @sync begin
96119 @async begin
97- if selfvalpresent
98- selfval = take! (pipe. selfchannels. out):: Tmap
99- selfvalred = freduce ((value (selfval),))
100- pv = pval (rank,selfvalred)
101- ind = selfvalpresent + leftchild
102- vals[ind] = pv
103- end
120+ selfval = take! (pipe. selfchannels. out):: Tmap
121+ selfvalred = freduce ((value (selfval),))
122+ pv = pval (rank,selfvalred)
123+ ind = leftchild + 1
124+ vals[ind] = pv
104125 end
105- @async begin
106- if selfvalpresent
107- for i= 1 : N
108- pv = take! (pipe. childrenchannels. out) :: Tred
109- shift = pv. rank > rank ? 1 : - 1
110- ind = shift + leftchild + 1
111- vals[ind] = pv
112- end
113- else
114- for i= 1 : N
115- pv = take! (pipe. childrenchannels. out) :: Tred
116- vals[i] = pv
117- end
118- sort! (vals,by= pv-> pv. rank)
119- end
126+ @async for i= 1 : N
127+ pv = take! (pipe. childrenchannels. out) :: Tred
128+ shift = pv. rank > rank ? 1 : - 1
129+ ind = shift + leftchild + 1
130+ vals[ind] = pv
120131 end
121132 end
122133
123134 Tred (rank,freduce (value (v) for v in vals))
124135end
136+ function reducedvalue (freduce:: Function ,node:: TopTreeNode ,
137+ pipe:: BranchChannel{<:Any,Tred} ,:: Sorted ) where {Tred}
138+
139+ rank = node. rank
140+ N = nchildren (pipe)
141+ leftchild = N > 0
142+ @assert leftchild " Nodes on the top tree must have children"
143+ vals = Vector {Tred} (undef,N)
144+ for i= 1 : N
145+ pv = take! (pipe. childrenchannels. out) :: Tred
146+ vals[i] = pv
147+ end
148+ sort! (vals,by= pv-> pv. rank)
149+
150+ Tred (rank,freduce (value (v) for v in vals))
151+ end
125152
126153function indicatereduceprogress! (:: Nothing ,rank) end
127154function indicatereduceprogress! (progress:: RemoteChannel ,rank)
128155 put! (progress,(false ,true ,rank))
129156end
130157
131- function reduceTreeNode (freduce:: Function ,rank,pipe:: BranchChannel{Tmap,Tred} ,
132- ifsort:: Ordering ,progress:: Union{Nothing,RemoteChannel} ) where {Tmap,Tred}
158+ function reduceTreeNode (freduce:: Function ,rank,pipe:: BranchChannel ,
159+ ifsort:: Ordering ,progress)
160+
161+ reduceTreeNode (freduce,
162+ rank > 0 ? SubTreeNode (rank) : TopTreeNode (rank),
163+ pipe,ifsort,progress)
164+ end
165+
166+ function checkerror (:: SubTreeNode ,pipe:: BranchChannel )
167+ selferr = take! (pipe. selfchannels. err)
168+ childrenerr = any (take! (pipe. childrenchannels. err) for i= 1 : nchildren (pipe))
169+ selferr || childrenerr
170+ end
171+ function checkerror (:: TopTreeNode ,pipe:: BranchChannel )
172+ any (take! (pipe. childrenchannels. err) for i= 1 : nchildren (pipe))
173+ end
174+
175+ function reduceTreeNode (freduce:: Function ,node:: ReductionNode ,
176+ pipe:: BranchChannel{<:Any,Tred} ,
177+ ifsort:: Ordering ,progress:: Union{Nothing,RemoteChannel} ) where {Tred}
133178 # This function that communicates with the parent and children
134179
135180 # Start by checking if there is any error locally in the map,
136181 # and if there's none then check if there are any errors on the children
137- if rank > 0
138- anyerr = take! (pipe. selfchannels. err)
139- else
140- anyerr = false
141- end
142- anyerr = anyerr ||
143- any (take! (pipe. childrenchannels. err) for i= 1 : nchildren (pipe))
144-
182+ anyerr = checkerror (node,pipe)
183+ rank = node. rank
145184 # Evaluate the reduction only if there's no error
146185 # In either case push the error flag to the parent
147186 if ! anyerr
148187 try
149- res = reducedvalue (freduce,rank ,pipe,ifsort) :: Tred
188+ res = reducedvalue (freduce,node ,pipe,ifsort) :: Tred
150189 put! (pipe. parentchannels. out,res)
151190 put! (pipe. parentchannels. err,false )
152191 indicatereduceprogress! (progress,rank)
@@ -170,7 +209,7 @@ function return_unless_error(r::RemoteChannelContainer)
170209 end
171210end
172211
173- @inline function return_unless_error (b:: BranchChannel )
212+ function return_unless_error (b:: BranchChannel )
174213 return_unless_error (b. parentchannels)
175214end
176215
@@ -193,19 +232,22 @@ function pmapreduceworkers(fmap::Function,freduce::Function,iterators::Tuple,
193232
194233 for (ind,mypipe) in enumerate (branches)
195234 p = mypipe. p
196- rank = ind - extrareducenodes
197- if rank > 0
235+ ind_reduced = ind - extrareducenodes
236+ rank = ind_reduced
237+ if ind_reduced > 0
198238 iterable_on_proc = ProductSplit (iterators,num_workers_active,rank)
199239
200240 @spawnat p mapTreeNode (fmap,iterable_on_proc,rank,mypipe,
201- ifelse ( showprogress, progresschannel, nothing ) ,
241+ showprogress ? progresschannel : nothing ,
202242 args... ;kwargs... )
203243
204- @spawnat p reduceTreeNode (freduce,rank,mypipe,ord,
205- ifelse (showprogress,progresschannel,nothing ))
244+ @spawnat p reduceTreeNode (freduce,SubTreeNode (rank),
245+ mypipe,ord,
246+ showprogress ? progresschannel : nothing )
206247 else
207- @spawnat p reduceTreeNode (freduce,rank,mypipe,ord,
208- ifelse (showprogress,progresschannel,nothing ))
248+ @spawnat p reduceTreeNode (freduce,TopTreeNode (rank),
249+ mypipe,ord,
250+ showprogress ? progresschannel : nothing )
209251 end
210252 end
211253
0 commit comments