Skip to content

Commit 9e99c8a

Browse files
committed
Revert commit c96641e - num_replicas parameter is enough
1 parent ac93ca9 commit 9e99c8a

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

plasma/models/mpi_runner.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_val(self):
139139

140140

141141
class MPIModel():
142-
def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01,custom_num_workers=0):
142+
def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01):
143143
# random.seed(task_index)
144144
self.epoch = 0
145145
self.model = model
@@ -151,19 +151,13 @@ def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=No
151151
self.batch_size = batch_size
152152
self.batch_iterator = batch_iterator
153153
self.warmup_steps=warmup_steps
154-
if custom_num_workers:
155-
if custom_num_workers < comm.Get_size():
156-
self.num_workers = custom_num_workers
157-
else: self.num_workers = comm.Get_size()
158-
else:
159-
self.num_workers = comm.Get_size()
154+
self.num_workers = comm.Get_size()
160155
self.task_index = comm.Get_rank()
161156
self.history = cbks.History()
162157
if num_replicas is None or num_replicas < 1 or num_replicas > self.num_workers:
163-
self.num_replicas = self.num_workers
158+
self.num_replicas = self.num_workers
164159
else:
165-
self.num_replicas = num_replicas
166-
160+
self.num_replicas = num_replicas
167161

168162
def set_lr(self,lr):
169163
self.lr = lr

0 commit comments

Comments
 (0)