-
Notifications
You must be signed in to change notification settings - Fork 41
Open
Description
The original `__getitem__` function:
def __getitem__(self, batch_index: int) -> dict:
    """Load, preprocess and return one labelled image sample.

    Args:
        batch_index: Index into ``self.image_file_paths``. Despite the name,
            it selects a single sample, not a batch.

    Returns:
        A dict with key ``"image"`` (the preprocessed image tensor produced by
        ``self.post_transform``) and key ``"target"`` (the class index looked
        up in ``self.class_to_idx`` by the image's parent directory name).

    Raises:
        ValueError: If the file extension is not listed in ``IMG_EXTENSIONS``.
        OSError: If OpenCV cannot decode the image file.
    """
    image_path = self.image_file_paths[batch_index]
    # The last two path components are <class directory> and <file name>.
    image_dir, image_name = image_path.split(self.delimiter)[-2:]

    # Reject unsupported file types before doing any disk I/O.
    if image_name.split(".")[-1].lower() not in IMG_EXTENSIONS:
        raise ValueError(f"Unsupported image extensions, Only support `{IMG_EXTENSIONS}`, "
                         "please check the image file extensions.")

    image = cv2.imread(image_path)
    target = self.class_to_idx[image_dir]
    # cv2.imread signals failure by returning None instead of raising. Without
    # this check the error would surface later as an opaque cv2.cvtColor failure.
    if image is None:
        raise OSError(f"Failed to read image file `{image_path}`.")

    # OpenCV decodes to BGR; convert to RGB before handing off to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    # Data preprocess (PIL-level transform).
    image = self.pre_transform(image)
    # Convert image data into Tensor format (PyTorch).
    # Note: The range of input and output is between [0, 1]
    tensor = imgproc.image_to_tensor(image, False, False)
    # Data postprocess (tensor-level transform).
    tensor = self.post_transform(tensor)

    # PyTorch's default collate_fn batches Mapping samples key-wise, so a dict
    # sample works with a default DataLoader.
    return {"image": tensor, "target": target}
# But according to pytorch doc,
the correct form should be:
def __getitem__(self, batch_index: int) -> tuple:
    """Load, preprocess and return one labelled image sample as a tuple.

    Args:
        batch_index: Index into ``self.image_file_paths``. Despite the name,
            it selects a single sample, not a batch.

    Returns:
        ``(tensor, target)``: the preprocessed image tensor produced by
        ``self.post_transform``, and the class index looked up in
        ``self.class_to_idx`` by the image's parent directory name.

    Raises:
        ValueError: If the file extension is not listed in ``IMG_EXTENSIONS``.
        OSError: If OpenCV cannot decode the image file.
    """
    image_path = self.image_file_paths[batch_index]
    # The last two path components are <class directory> and <file name>.
    image_dir, image_name = image_path.split(self.delimiter)[-2:]

    # Reject unsupported file types before doing any disk I/O.
    if image_name.split(".")[-1].lower() not in IMG_EXTENSIONS:
        raise ValueError(f"Unsupported image extensions, Only support `{IMG_EXTENSIONS}`, "
                         "please check the image file extensions.")

    image = cv2.imread(image_path)
    target = self.class_to_idx[image_dir]
    # cv2.imread signals failure by returning None instead of raising. Without
    # this check the error would surface later as an opaque cv2.cvtColor failure.
    if image is None:
        raise OSError(f"Failed to read image file `{image_path}`.")

    # OpenCV decodes to BGR; convert to RGB before handing off to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    # Data preprocess (PIL-level transform).
    image = self.pre_transform(image)
    # Convert image data into Tensor format (PyTorch).
    # Note: The range of input and output is between [0, 1]
    tensor = imgproc.image_to_tensor(image, False, False)
    # Data postprocess (tensor-level transform).
    tensor = self.post_transform(tensor)

    return tensor, target
# PS: I haven't run this code to test the idea.
Metadata
Metadata
Assignees
Labels
No labels