|
15 | 15 | _logger = logging.getLogger(__name__) |
16 | 16 |
|
17 | 17 |
|
18 | | -_ERROR_RETRY = 50 |
| 18 | +_ERROR_RETRY = 20 |
19 | 19 |
|
20 | 20 |
|
21 | 21 | class ImageDataset(data.Dataset): |
@@ -48,21 +48,19 @@ def __init__( |
48 | 48 | self.transform = transform |
49 | 49 | self.target_transform = target_transform |
50 | 50 | self.additional_features = additional_features |
51 | | - self._consecutive_errors = 0 |
| 51 | + self._max_retries = _ERROR_RETRY |
52 | 52 |
|
53 | 53 | def __getitem__(self, index): |
54 | | - img, target, *features = self.reader[index] |
55 | | - |
56 | | - try: |
57 | | - img = img.read() if self.load_bytes else Image.open(img) |
58 | | - except Exception as e: |
59 | | - _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}') |
60 | | - self._consecutive_errors += 1 |
61 | | - if self._consecutive_errors < _ERROR_RETRY: |
62 | | - return self.__getitem__((index + 1) % len(self.reader)) |
63 | | - else: |
64 | | - raise e |
65 | | - self._consecutive_errors = 0 |
| 54 | + for attempt in range(self._max_retries): |
| 55 | + try: |
| 56 | + img, target, *features = self.reader[index] |
| 57 | + img = img.read() if self.load_bytes else Image.open(img) |
| 58 | + break |
| 59 | + except (IOError, OSError) as e: # be specific |
| 60 | + _logger.warning(f'Skipped sample (index {index}). {e}') |
| 61 | + index = (index + 1) % len(self.reader) |
| 62 | + else: |
| 63 | + raise RuntimeError(f"Failed to load {self._max_retries} consecutive samples") |
66 | 64 |
|
67 | 65 | if self.input_img_mode and not self.load_bytes: |
68 | 66 | img = img.convert(self.input_img_mode) |
@@ -134,7 +132,6 @@ def __init__( |
134 | 132 | self.reader = reader |
135 | 133 | self.transform = transform |
136 | 134 | self.target_transform = target_transform |
137 | | - self._consecutive_errors = 0 |
138 | 135 |
|
139 | 136 | def __iter__(self): |
140 | 137 | for img, target in self.reader: |
|
0 commit comments