Thanks for the simple and elegant implementation!
I tried running your code as-is on the Multi-MNIST data and failed to reproduce the results.
I ran main_multi_mnist.py without changing any hyperparameters (learning rate 0.0005, batch size 256, 100 epochs). For comparison, I created a version without PCGrad (see the sketch after this list):
1. Comment out line 57: `optimizer = PCGrad(optimizer)`
2. Replace line 72: `optimizer.pc_backward(losses)` with `torch.sum(torch.stack(losses)).backward()`
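
For clarity, here is a minimal sketch of the two training steps I compared. The names `optimizer` and `losses` follow the script; the `train_step` wrapper and its signature are my own abbreviation for illustration, not the repo's actual code.

```python
import torch

def train_step(optimizer, losses, use_pcgrad):
    """One optimization step; `losses` is a list of per-task scalar loss tensors.

    With use_pcgrad=True, `optimizer` is assumed to be the repo's PCGrad
    wrapper around a torch optimizer; otherwise it is the plain optimizer.
    """
    optimizer.zero_grad()
    if use_pcgrad:
        # PCGrad variant: per-task backward passes with gradient projection
        optimizer.pc_backward(losses)
    else:
        # Baseline variant: backprop the plain sum of the task losses
        torch.sum(torch.stack(losses)).backward()
    optimizer.step()
```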
I ran each version 7 times. My results (averaging left-digit and right-digit accuracy) are:

| Variant | Average accuracy | Max accuracy | Std. dev. |
| --- | --- | --- | --- |
| Without PCGrad | 89.5% | 89.9% | 0.38 |
| With PCGrad | 89.5% | 89.8% | 0.20 |
Can you come up with an explanation?
Many thanks,
Noa Garnett