Skip to content

Commit 84d47ac

Browse files
committed
follow comments
1 parent 9d2f49c commit 84d47ac

File tree

1 file changed

+22
-33
lines changed

1 file changed

+22
-33
lines changed

python/paddle/utils/image_multiproc.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from PIL import Image
44
from cStringIO import StringIO
55
import multiprocessing
6-
from functools import partial
6+
import functools
7+
import itertools
78

89
from paddle.utils.image_util import *
910
from paddle.trainer.config_parser import logger
@@ -14,10 +15,12 @@
1415
logger.warning("OpenCV2 is not installed, using PIL to prcoess")
1516
cv2 = None
1617

18+
__all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"]
1719

18-
class CvTransfomer(ImageTransformer):
20+
21+
class CvTransformer(ImageTransformer):
1922
"""
20-
CvTransfomer used python-opencv to process image.
23+
CvTransformer used python-opencv to process image.
2124
"""
2225

2326
def __init__(
@@ -97,9 +100,9 @@ def transform_from_file(self, file):
97100
return self.transform(im)
98101

99102

100-
class PILTransfomer(ImageTransformer):
103+
class PILTransformer(ImageTransformer):
101104
"""
102-
PILTransfomer used PIL to process image.
105+
PILTransformer used PIL to process image.
103106
"""
104107

105108
def __init__(
@@ -170,8 +173,11 @@ def transform_from_file(self, file):
170173
return self.transform(im)
171174

172175

173-
def warpper(cls, (dat, label)):
174-
return cls.job(dat, label)
176+
def job(is_img_string, transformer, (data, label)):
177+
if is_img_string:
178+
return transformer.transform_from_string(data), label
179+
else:
180+
return transformer.transform_from_file(data), label
175181

176182

177183
class MultiProcessImageTransformer(object):
@@ -238,36 +244,19 @@ def process(settings, file_list):
238244
:type is_img_string: bool.
239245
"""
240246

247+
self.procnum = procnum
241248
self.pool = multiprocessing.Pool(procnum)
242249
self.is_img_string = is_img_string
243250
if cv2 is not None:
244-
self.transformer = CvTransfomer(resize_size, crop_size, transpose,
245-
channel_swap, mean, is_train,
246-
is_color)
247-
else:
248-
self.transformer = PILTransfomer(resize_size, crop_size, transpose,
251+
self.transformer = CvTransformer(resize_size, crop_size, transpose,
249252
channel_swap, mean, is_train,
250253
is_color)
251-
252-
def run(self, data, label):
253-
try:
254-
fun = partial(warpper, self)
255-
return self.pool.imap_unordered(fun, zip(data, label), chunksize=5)
256-
except KeyboardInterrupt:
257-
self.pool.terminate()
258-
except Exception, e:
259-
self.pool.terminate()
260-
261-
def job(self, data, label):
262-
if self.is_img_string:
263-
return self.transformer.transform_from_string(data), label
264254
else:
265-
return self.transformer.transform_from_file(data), label
266-
267-
def __getstate__(self):
268-
self_dict = self.__dict__.copy()
269-
del self_dict['pool']
270-
return self_dict
255+
self.transformer = PILTransformer(resize_size, crop_size, transpose,
256+
channel_swap, mean, is_train,
257+
is_color)
271258

272-
def __setstate__(self, state):
273-
self.__dict__.update(state)
259+
def run(self, data, label):
260+
fun = functools.partial(job, self.is_img_string, self.transformer)
261+
return self.pool.imap_unordered(
262+
fun, itertools.izip(data, label), chunksize=100 * self.procnum)

0 commit comments

Comments
 (0)