Skip to content

Commit 692d95f

Browse files
committed
Supporting materials for the PR4
1 parent aedebdd commit 692d95f

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python
2+
from functools import partial
3+
from itertools import tee
4+
5+
class Loader():
6+
def example_batch_generator(self,n):
7+
for batch in range(n):
8+
yield batch
9+
10+
class MPIModel():
11+
def __init__(self,batch_generator):
12+
self.batch_iterator = batch_generator
13+
14+
def train_epochs(self,M):
15+
num_total = 8
16+
for epoch in range(M):
17+
num_so_far = 0
18+
print ("Batch iter. summary: {}{}".format(self,self.batch_iterator))
19+
for batch in self.batch_iterator():
20+
num_so_far += 1
21+
22+
whatever=batch
23+
print ("Next batch id: {}".format(batch))
24+
if num_so_far > num_total: break
25+
print "+++++++"
26+
27+
28+
class MPIModel_default():
29+
def __init__(self,batch_generator):
30+
self.batch_iterator = batch_generator
31+
32+
def train_epochs(self,M):
33+
num_total = 8 #number of samples per epoch
34+
batch_generator_func = self.batch_iterator()
35+
36+
for iepoch in range(M):
37+
#print ("Batch iter. summary: {}{} epoch: {}".format(self,self.batch_iterator,iepoch))
38+
num_so_far = 0
39+
40+
while num_so_far < num_total:
41+
num_so_far += 1
42+
43+
try:
44+
batch = batch_generator_func.next()
45+
except:
46+
batch_generator_func = self.batch_iterator()
47+
batch = batch_generator_func.next()
48+
print ("Next batch id: {}".format(batch))
49+
50+
print "+++++++"
51+
52+
53+
54+
def main():
55+
num_batches = 10
56+
epochs = 3
57+
58+
loader = Loader()
59+
batch_generator = partial(loader.example_batch_generator,n=num_batches)
60+
my_example_class = MPIModel(batch_generator)
61+
my_example_class.train_epochs(epochs)
62+
63+
def main_default():
64+
num_batches = 10
65+
epochs = 3
66+
67+
loader = Loader()
68+
batch_generator = partial(loader.example_batch_generator,n=num_batches)
69+
my_example_class = MPIModel_default(batch_generator)
70+
my_example_class.train_epochs(epochs)
71+
72+
if __name__=='__main__':
73+
import timeit
74+
#print min(timeit.Timer(setup=main).repeat(7, 1000))
75+
print min(timeit.Timer(setup=main_default).repeat(7, 1000))

0 commit comments

Comments
 (0)