|
3 | 3 | from PIL import Image |
4 | 4 | from cStringIO import StringIO |
5 | 5 | import multiprocessing |
6 | | -from functools import partial |
| 6 | +import functools |
| 7 | +import itertools |
7 | 8 |
|
8 | 9 | from paddle.utils.image_util import * |
9 | 10 | from paddle.trainer.config_parser import logger |
|
14 | 15 | logger.warning("OpenCV2 is not installed, using PIL to prcoess") |
15 | 16 | cv2 = None |
16 | 17 |
|
| 18 | +__all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"] |
17 | 19 |
|
18 | | -class CvTransfomer(ImageTransformer): |
| 20 | + |
| 21 | +class CvTransformer(ImageTransformer): |
19 | 22 | """ |
20 | | - CvTransfomer used python-opencv to process image. |
| 23 | + CvTransformer used python-opencv to process image. |
21 | 24 | """ |
22 | 25 |
|
23 | 26 | def __init__( |
@@ -97,9 +100,9 @@ def transform_from_file(self, file): |
97 | 100 | return self.transform(im) |
98 | 101 |
|
99 | 102 |
|
100 | | -class PILTransfomer(ImageTransformer): |
| 103 | +class PILTransformer(ImageTransformer): |
101 | 104 | """ |
102 | | - PILTransfomer used PIL to process image. |
| 105 | + PILTransformer used PIL to process image. |
103 | 106 | """ |
104 | 107 |
|
105 | 108 | def __init__( |
@@ -170,8 +173,11 @@ def transform_from_file(self, file): |
170 | 173 | return self.transform(im) |
171 | 174 |
|
172 | 175 |
|
173 | | -def warpper(cls, (dat, label)): |
174 | | - return cls.job(dat, label) |
| 176 | +def job(is_img_string, transformer, (data, label)): |
| 177 | + if is_img_string: |
| 178 | + return transformer.transform_from_string(data), label |
| 179 | + else: |
| 180 | + return transformer.transform_from_file(data), label |
175 | 181 |
|
176 | 182 |
|
177 | 183 | class MultiProcessImageTransformer(object): |
@@ -238,36 +244,19 @@ def process(settings, file_list): |
238 | 244 | :type is_img_string: bool. |
239 | 245 | """ |
240 | 246 |
|
| 247 | + self.procnum = procnum |
241 | 248 | self.pool = multiprocessing.Pool(procnum) |
242 | 249 | self.is_img_string = is_img_string |
243 | 250 | if cv2 is not None: |
244 | | - self.transformer = CvTransfomer(resize_size, crop_size, transpose, |
245 | | - channel_swap, mean, is_train, |
246 | | - is_color) |
247 | | - else: |
248 | | - self.transformer = PILTransfomer(resize_size, crop_size, transpose, |
| 251 | + self.transformer = CvTransformer(resize_size, crop_size, transpose, |
249 | 252 | channel_swap, mean, is_train, |
250 | 253 | is_color) |
251 | | - |
252 | | - def run(self, data, label): |
253 | | - try: |
254 | | - fun = partial(warpper, self) |
255 | | - return self.pool.imap_unordered(fun, zip(data, label), chunksize=5) |
256 | | - except KeyboardInterrupt: |
257 | | - self.pool.terminate() |
258 | | - except Exception, e: |
259 | | - self.pool.terminate() |
260 | | - |
261 | | - def job(self, data, label): |
262 | | - if self.is_img_string: |
263 | | - return self.transformer.transform_from_string(data), label |
264 | 254 | else: |
265 | | - return self.transformer.transform_from_file(data), label |
266 | | - |
267 | | - def __getstate__(self): |
268 | | - self_dict = self.__dict__.copy() |
269 | | - del self_dict['pool'] |
270 | | - return self_dict |
| 255 | + self.transformer = PILTransformer(resize_size, crop_size, transpose, |
| 256 | + channel_swap, mean, is_train, |
| 257 | + is_color) |
271 | 258 |
|
272 | | - def __setstate__(self, state): |
273 | | - self.__dict__.update(state) |
| 259 | + def run(self, data, label): |
| 260 | + fun = functools.partial(job, self.is_img_string, self.transformer) |
| 261 | + return self.pool.imap_unordered( |
| 262 | + fun, itertools.izip(data, label), chunksize=100 * self.procnum) |
0 commit comments