@@ -577,6 +577,8 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
577577
578578 eachnode_reducechannel = node_channels[node]. out
579579 eachnode_errorchannel = node_channels[node]. err
580+
581+ p_parent = eachnode_reducechannel. where
580582
581583 np_node = nprocs_node_dict[node]
582584
@@ -591,14 +593,33 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
591593 put! (eachnode_errorchannel,true )
592594 rethrow ()
593595 finally
594- if p ∉ procid_rank1_on_node
596+ # Don't need references to the intermediate reduction channels on other nodes
597+ for (node_i,channels_i) in node_channels
598+ if node_i == node
599+ continue
600+ end
601+
602+ finalize (channels_i. out)
603+ finalize (channels_i. err)
604+ end
605+ if p != p_parent
595606 finalize (eachnode_errorchannel)
596607 finalize (eachnode_reducechannel)
608+ finalize (finalnode_errorchannel)
609+ finalize (finalnode_reducechannel)
610+ end
611+ if p != p_final
612+ if p != result_channel. where
613+ finalize (result_channel)
614+ end
615+ if p != error_channel. where
616+ finalize (error_channel)
617+ end
597618 end
598619 end
599620 end
600621
601- @async if p in procid_rank1_on_node
622+ @async if p == p_parent
602623 @spawnat p begin
603624 try
604625 anyerror = any (take! (eachnode_errorchannel) for i= 1 : np_node)
@@ -613,7 +634,9 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
613634 put! (finalnode_errorchannel,true )
614635 rethrow ()
615636 finally
637+ finalize (eachnode_errorchannel)
616638 finalize (eachnode_reducechannel)
639+
617640 if p != p_final
618641 finalize (finalnode_errorchannel)
619642 finalize (finalnode_reducechannel)
@@ -700,27 +723,29 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
700723 nprocs_node_dict = nprocs_node (procs_used)
701724 node_channels = Dict (
702725 node=> (
703- out = RemoteChannel (()-> Channel {Any } (nprocs_node_dict[node]),procid_node),
726+ out = RemoteChannel (()-> Channel {pval } (nprocs_node_dict[node]),procid_node),
704727 err = RemoteChannel (()-> Channel {Bool} (nprocs_node_dict[node]),procid_node),
705728 )
706729 for (node,procid_node) in zip (nodes,procid_rank1_on_node))
707730
708731 # Worker at which the final reduction takes place
709732 p_final = first (procid_rank1_on_node)
710733
711- finalnode_reducechannel = RemoteChannel (()-> Channel {pval} (length (procid_rank1_on_node) ),p_final)
712- finalnode_errorchannel = RemoteChannel (()-> Channel {Bool} (length (procid_rank1_on_node) ),p_final)
734+ finalnode_reducechannel = RemoteChannel (()-> Channel {pval} (Nnodes_reduction ),p_final)
735+ finalnode_errorchannel = RemoteChannel (()-> Channel {Bool} (Nnodes_reduction ),p_final)
713736
714737 result_channel = RemoteChannel (()-> Channel {Any} (1 ))
715738 error_channel = RemoteChannel (()-> Channel {Bool} (1 ))
716739
717- # Run the function on each processor and compute the sum at each node
740+ # Run the function on each processor and compute the reduction at each node
718741 @sync for (rank,(p,node)) in enumerate (zip (procs_used,hostnames))
719742 @async begin
720743
721744 eachnode_reducechannel = node_channels[node]. out
722745 eachnode_errorchannel = node_channels[node]. err
723746
747+ p_parent = eachnode_reducechannel. where
748+
724749 np_node = nprocs_node_dict[node]
725750
726751 iterable_on_proc = evenlyscatterproduct (iterable,num_workers,rank)
@@ -733,14 +758,33 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
733758 put! (eachnode_errorchannel,true )
734759 rethrow ()
735760 finally
736- if p ∉ procid_rank1_on_node
761+ # Don't need references to the intermediate reduction channels on other nodes
762+ for (node_i,channels_i) in node_channels
763+ if node_i == node
764+ continue
765+ end
766+
767+ finalize (channels_i. out)
768+ finalize (channels_i. err)
769+ end
770+ if p != p_parent
737771 finalize (eachnode_errorchannel)
738772 finalize (eachnode_reducechannel)
773+ finalize (finalnode_errorchannel)
774+ finalize (finalnode_reducechannel)
739775 end
740- end
776+ if p != p_final
777+ if p != result_channel. where
778+ finalize (result_channel)
779+ end
780+ if p != error_channel. where
781+ finalize (error_channel)
782+ end
783+ end
784+ end
741785 end
742786
743- @async if p in procid_rank1_on_node
787+ @async if p == p_parent
744788 @spawnat p begin
745789 try
746790 anyerror = any (take! (eachnode_errorchannel) for i= 1 : np_node)
@@ -759,6 +803,7 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
759803 finally
760804 finalize (eachnode_errorchannel)
761805 finalize (eachnode_reducechannel)
806+
762807 if p != p_final
763808 finalize (finalnode_errorchannel)
764809 finalize (finalnode_reducechannel)
@@ -786,6 +831,7 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
786831 finally
787832 finalize (finalnode_errorchannel)
788833 finalize (finalnode_reducechannel)
834+
789835 if p != result_channel. where
790836 finalize (result_channel)
791837 end
0 commit comments