From 997510d34ce394af6e44acebaa6b07894fe799e4 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Mon, 16 Feb 2026 06:55:12 -0500
Subject: [PATCH] Fix GlobalOptimManager.override_config not propagating to
 optimizer (#1269)

override_config only wrote to pid2config, but get_config only read from
index2config. When override_config was called after register_parameters
(the documented usage order), the override was never seen by the
optimizer.

Fix by having get_config also check pid2config as a fallback after
index2config, so overrides work regardless of call order.

Co-Authored-By: Claude Opus 4.6
---
 bitsandbytes/optim/optimizer.py |  7 +++++++
 tests/test_optim.py             | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index db7a35231..07d736d06 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -352,6 +352,13 @@ def get_config(self, gindex, pindex, group):
 
         if (gindex, pindex) in self.mng.index2config:
             config.update(self.mng.index2config[(gindex, pindex)])
+
+        # Also check pid2config as a fallback so that override_config works
+        # regardless of whether it was called before or after register_parameters.
+        p = self.param_groups[gindex]["params"][pindex]
+        if id(p) in self.mng.pid2config:
+            config.update(self.mng.pid2config[id(p)])
+
         return config
 
     def init_state(self, group, p, gindex, pindex):
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 190d9a206..ae58a8249 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -302,6 +302,39 @@
     assert adam2.state[p3]["state2"].dtype == torch.uint8
 
 
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
+def test_override_config_after_register(device):
+    """Test that override_config works when called after register_parameters (issue #1269)."""
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
+    mng = bnb.optim.GlobalOptimManager.get_instance()
+    mng.initialize()
+
+    p1 = torch.randn(64, 64, device="cpu") * 0.1
+    p2 = torch.randn(64, 64, device="cpu") * 0.1
+
+    # Register first, override second (the documented order)
+    mng.register_parameters([p1, p2])
+    p1 = p1.to(device)
+    p2 = p2.to(device)
+
+    # Override p2 to use 8-bit after register_parameters
+    mng.override_config(p2, "optim_bits", 8)
+
+    adam = bnb.optim.Adam([p1, p2], lr=0.001, optim_bits=32)
+
+    # Run a step to trigger init_state
+    p1.grad = torch.randn_like(p1) * 0.1
+    p2.grad = torch.randn_like(p2) * 0.1
+    adam.step()
+
+    # p1 should be 32-bit, p2 should be 8-bit
+    assert adam.state[p1]["state1"].dtype == torch.float32
+    assert adam.state[p2]["state1"].dtype == torch.uint8
+
+
 optimizer_names_8bit = [
     "adam8bit_blockwise",
     "lion8bit_blockwise",
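
For context, the call order this fix is meant to support mirrors the documented
GlobalOptimManager workflow. A minimal sketch of that order, assuming the fix
above is applied (the Linear module, sizes, and hyperparameters here are
illustrative, not taken from this patch or the docs):

    import torch
    import bitsandbytes as bnb

    mng = bnb.optim.GlobalOptimManager.get_instance()

    model = torch.nn.Linear(64, 64)              # illustrative module
    mng.register_parameters(model.parameters())  # 1. register while still on CPU
    model = model.cuda()

    # 8-bit optimizer states for all parameters by default
    adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

    # 2. override after registration: before this fix, get_config never
    #    consulted pid2config, so this override was silently ignored.
    mng.override_config(model.weight, "optim_bits", 32)

With the pid2config fallback in get_config, the override takes effect when
init_state runs on the first adam.step().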