Status: Open
Labels: bug (Something isn't working)
Description
I'm trying to use the MetricLossOnly trainer with the HierarchicalSampler, but HierarchicalSampler inherits from Sampler instead of BatchSampler, so the following block is never executed:
pytorch-metric-learning/src/pytorch_metric_learning/utils/common_functions.py, lines 176 to 183 in c835099:

```python
if isinstance(sampler, torch.utils.data.BatchSampler):
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )
```
As a result, training fails: instead of a 4D batch tensor of shape (b, c, h, w), I get a 5D tensor of shape (b, b, c, h, w), followed by an error.
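The shape mismatch can be reproduced with plain PyTorch. This is a minimal sketch, assuming a hypothetical ToyGroupedSampler (a stand-in for HierarchicalSampler, not its real code) that subclasses Sampler but yields whole batches of indices, and assuming the non-BatchSampler path hands the sampler to DataLoader via `sampler=` together with a `batch_size` (the issue does not quote that branch):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Sampler, TensorDataset


class ToyGroupedSampler(Sampler):
    """Hypothetical stand-in: subclasses Sampler, but each item it yields
    is a whole batch (a list of dataset indices)."""

    def __init__(self, num_samples, batch_size):
        self.num_samples = num_samples
        self.batch_size = batch_size

    def __iter__(self):
        # Yield consecutive, complete batches of indices.
        for start in range(0, self.num_samples - self.batch_size + 1, self.batch_size):
            yield list(range(start, start + self.batch_size))

    def __len__(self):
        return self.num_samples // self.batch_size


# 16 fake images of shape (c, h, w) = (3, 8, 8) plus integer labels.
dataset = TensorDataset(torch.randn(16, 3, 8, 8), torch.arange(16))
sampler = ToyGroupedSampler(16, 4)

# The isinstance check fails, so the batch_sampler branch is skipped.
print(isinstance(sampler, BatchSampler))  # False

# Assumed fallback: DataLoader wraps the sampler in its own BatchSampler,
# so each "index" is already a list and the batch gets an extra dimension.
wrong = DataLoader(dataset, sampler=sampler, batch_size=4)
# Passing the same object as batch_sampler gives the intended layout.
right = DataLoader(dataset, batch_sampler=sampler)

wrong_x, _ = next(iter(wrong))
right_x, _ = next(iter(right))
print(tuple(wrong_x.shape))  # (4, 4, 3, 8, 8)
print(tuple(right_x.shape))  # (4, 3, 8, 8)
```

Passed as `sampler=`, each yielded index list is treated as one index; TensorDataset answers a list index with a whole (b, c, h, w) block, and the default collate stacks b of those blocks into (b, b, c, h, w). A dataset that does not accept list indices would raise an error instead.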
I haven't dug further, but so far training seems to work if I force HierarchicalSampler to inherit from BatchSampler.
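The workaround described above can be sketched the same way. This assumes a hypothetical ToyGroupedBatchSampler that subclasses BatchSampler, plus a hypothetical make_loader helper whose first branch mirrors the quoted isinstance check; neither is the library's actual code:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset


class ToyGroupedBatchSampler(BatchSampler):
    """Hypothetical batch-yielding sampler that subclasses BatchSampler,
    so isinstance(..., BatchSampler) is True."""

    def __init__(self, num_samples, batch_size):
        # BatchSampler.__init__ expects (sampler, batch_size, drop_last);
        # it is skipped because this class builds index lists itself.
        self.num_samples = num_samples
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, self.num_samples - self.batch_size + 1, self.batch_size):
            yield list(range(start, start + self.batch_size))

    def __len__(self):
        return self.num_samples // self.batch_size


def make_loader(dataset, sampler, batch_size):
    # Hypothetical helper: only the isinstance branch mirrors the quoted code.
    if isinstance(sampler, BatchSampler):
        return DataLoader(dataset, batch_sampler=sampler)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)


dataset = TensorDataset(torch.randn(16, 3, 8, 8), torch.arange(16))
loader = make_loader(dataset, ToyGroupedBatchSampler(16, 4), batch_size=4)
x, y = next(iter(loader))
print(tuple(x.shape))  # (4, 3, 8, 8)
```

Skipping BatchSampler.__init__ only works here because __iter__ and __len__ are overridden, so the attributes BatchSampler would normally set are never read; DataLoader itself only iterates the batch sampler.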
Why did you change it from BatchSampler to Sampler?
Thanks for this fantastic library!