Setting `reduction='sum'` does not work because of [this line](https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/e987ac603fa1accd386820a985a6dc2fd92dec5b/pcgrad.py#L58):

```
if self._reduction:
    merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
```

`self._reduction` is a string, and any non-empty string is truthy, so this condition always passes and the `mean` branch runs even when `'sum'` is requested.
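To make the problem concrete, here is a minimal sketch in plain Python (no torch, so it runs anywhere). The helper names `buggy_merge` and `fixed_merge` are hypothetical and not part of the repo, and the `'sum'` branch in the buggy version is an assumption added for illustration. A possible fix is to compare the string explicitly instead of relying on its truthiness:

```python
def buggy_merge(values, reduction="mean"):
    """Mimics the truthiness check quoted above (hypothetical helper)."""
    if reduction:  # any non-empty string is truthy, so this always runs
        return sum(values) / len(values)
    elif reduction == "sum":  # assumed branch; never reached for 'sum'
        return sum(values)
    raise ValueError(f"invalid reduction method: {reduction!r}")


def fixed_merge(values, reduction="mean"):
    """Compares the reduction name explicitly (one possible fix)."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    raise ValueError(f"invalid reduction method: {reduction!r}")


grads = [1.0, 2.0, 3.0]
print(bool("sum"))                 # non-empty string is truthy
print(buggy_merge(grads, "sum"))   # silently returns the mean
print(fixed_merge(grads, "sum"))   # returns the sum
```

In `pcgrad.py` the equivalent change would presumably be `if self._reduction == 'mean':`, though I have not tested that against the rest of the file.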