@@ -139,7 +139,7 @@ def get_val(self):
139139
140140
141141class MPIModel ():
142- def __init__ (self ,model ,optimizer ,comm ,batch_iterator ,batch_size ,num_replicas = None ,warmup_steps = 1000 ,lr = 0.01 , custom_num_workers = 0 ):
142+ def __init__ (self ,model ,optimizer ,comm ,batch_iterator ,batch_size ,num_replicas = None ,warmup_steps = 1000 ,lr = 0.01 ):
143143 # random.seed(task_index)
144144 self .epoch = 0
145145 self .model = model
@@ -151,19 +151,13 @@ def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=No
151151 self .batch_size = batch_size
152152 self .batch_iterator = batch_iterator
153153 self .warmup_steps = warmup_steps
154- if custom_num_workers :
155- if custom_num_workers < comm .Get_size ():
156- self .num_workers = custom_num_workers
157- else : self .num_workers = comm .Get_size ()
158- else :
159- self .num_workers = comm .Get_size ()
154+ self .num_workers = comm .Get_size ()
160155 self .task_index = comm .Get_rank ()
161156 self .history = cbks .History ()
162157 if num_replicas is None or num_replicas < 1 or num_replicas > self .num_workers :
163- self .num_replicas = self .num_workers
158+ self .num_replicas = self .num_workers
164159 else :
165- self .num_replicas = num_replicas
166-
160+ self .num_replicas = num_replicas
167161
168162 def set_lr (self ,lr ):
169163 self .lr = lr
0 commit comments