Skip to content

Commit 33ec6d7

Browse files
committed
Small improvement to map dataset retry logic
1 parent 8f7ab2e commit 33ec6d7

File tree

1 file changed: +12 −15 lines changed

timm/data/dataset.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
_logger = logging.getLogger(__name__)
1616

1717

18-
_ERROR_RETRY = 50
18+
_ERROR_RETRY = 20
1919

2020

2121
class ImageDataset(data.Dataset):
@@ -48,21 +48,19 @@ def __init__(
4848
self.transform = transform
4949
self.target_transform = target_transform
5050
self.additional_features = additional_features
51-
self._consecutive_errors = 0
51+
self._max_retries = _ERROR_RETRY
5252

5353
def __getitem__(self, index):
54-
img, target, *features = self.reader[index]
55-
56-
try:
57-
img = img.read() if self.load_bytes else Image.open(img)
58-
except Exception as e:
59-
_logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
60-
self._consecutive_errors += 1
61-
if self._consecutive_errors < _ERROR_RETRY:
62-
return self.__getitem__((index + 1) % len(self.reader))
63-
else:
64-
raise e
65-
self._consecutive_errors = 0
54+
for attempt in range(self._max_retries):
55+
try:
56+
img, target, *features = self.reader[index]
57+
img = img.read() if self.load_bytes else Image.open(img)
58+
break
59+
except (IOError, OSError) as e: # be specific
60+
_logger.warning(f'Skipped sample (index {index}). {e}')
61+
index = (index + 1) % len(self.reader)
62+
else:
63+
raise RuntimeError(f"Failed to load {self._max_retries} consecutive samples")
6664

6765
if self.input_img_mode and not self.load_bytes:
6866
img = img.convert(self.input_img_mode)
@@ -134,7 +132,6 @@ def __init__(
134132
self.reader = reader
135133
self.transform = transform
136134
self.target_transform = target_transform
137-
self._consecutive_errors = 0
138135

139136
def __iter__(self):
140137
for img, target in self.reader:

0 commit comments

Comments
 (0)