From b412c910fb083ce408c444b20e140d88c5a30c55 Mon Sep 17 00:00:00 2001 From: youngrok cha Date: Wed, 9 Apr 2025 17:19:52 +0900 Subject: [PATCH 1/2] [fix] define blocksize define blocksize, as just showing a number is a bit confusing --- bitsandbytes/optim/ademamix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/optim/ademamix.py b/bitsandbytes/optim/ademamix.py index e52d96589..928289adb 100644 --- a/bitsandbytes/optim/ademamix.py +++ b/bitsandbytes/optim/ademamix.py @@ -166,8 +166,9 @@ def init_state(self, group, p, gindex, pindex): self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device) self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device) + blocksize = 256 n = p.numel() - blocks = (n // 256) + bool(n % 256) + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device) state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) From 24bec2a5cf4d8bfbe3a708f162c4382bf12a0c8e Mon Sep 17 00:00:00 2001 From: youngrok cha Date: Wed, 9 Apr 2025 17:22:13 +0900 Subject: [PATCH 2/2] [fix] match code with ademamix also define blocksize, as in the previous commit --- bitsandbytes/optim/optimizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index d4656efc4..4bed9a7c3 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -475,9 +475,9 @@ def init_state(self, group, p, gindex, pindex): state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: + blocksize = 256 n = p.numel() - blocks = n // 256 - blocks += 1 if n % 256 > 0 else 0 + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) @@ -697,9 +697,9 @@ def 
init_state(self, group, p, gindex, pindex): state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: + blocksize = 256 n = p.numel() - blocks = n // 256 - blocks += 1 if n % 256 > 0 else 0 + blocks = (n // blocksize) + bool(n % blocksize) state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: