diff --git a/bitsandbytes/optim/ademamix.py b/bitsandbytes/optim/ademamix.py index e52d96589..928289adb 100644 --- a/bitsandbytes/optim/ademamix.py +++ b/bitsandbytes/optim/ademamix.py @@ -166,8 +166,9 @@ def init_state(self, group, p, gindex, pindex): self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device) self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device) + blocksize = 256 n = p.numel() - blocks = (n // 256) + bool(n % 256) + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device) state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index d4656efc4..4bed9a7c3 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -475,9 +475,9 @@ def init_state(self, group, p, gindex, pindex): state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: + blocksize = 256 n = p.numel() - blocks = n // 256 - blocks += 1 if n % 256 > 0 else 0 + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) @@ -697,9 +697,9 @@ def init_state(self, group, p, gindex, pindex): state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: + blocksize = 256 n = p.numel() - blocks = n // 256 - blocks += 1 if n % 256 > 0 else 0 + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: