diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index db7a35231..07d736d06 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -352,6 +352,13 @@ def get_config(self, gindex, pindex, group):
 
         if (gindex, pindex) in self.mng.index2config:
             config.update(self.mng.index2config[(gindex, pindex)])
+
+        # Also check pid2config as a fallback so that override_config works
+        # regardless of whether it was called before or after register_parameters.
+        p = self.param_groups[gindex]["params"][pindex]
+        if id(p) in self.mng.pid2config:
+            config.update(self.mng.pid2config[id(p)])
+
         return config
 
     def init_state(self, group, p, gindex, pindex):
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 190d9a206..ae58a8249 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -302,6 +302,39 @@ def test_global_config(dim1, dim2, gtype, device):
     assert adam2.state[p3]["state2"].dtype == torch.uint8
 
 
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
+def test_override_config_after_register(device):
+    """Test that override_config works when called after register_parameters (issue #1269)."""
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
+    mng = bnb.optim.GlobalOptimManager.get_instance()
+    mng.initialize()
+
+    p1 = torch.randn(64, 64, device="cpu") * 0.1
+    p2 = torch.randn(64, 64, device="cpu") * 0.1
+
+    # Register first, override second (the documented order)
+    mng.register_parameters([p1, p2])
+    p1 = p1.to(device)
+    p2 = p2.to(device)
+
+    # Override p2 to use 8-bit after register_parameters
+    mng.override_config(p2, "optim_bits", 8)
+
+    adam = bnb.optim.Adam([p1, p2], lr=0.001, optim_bits=32)
+
+    # Run a step to trigger init_state
+    p1.grad = torch.randn_like(p1) * 0.1
+    p2.grad = torch.randn_like(p2) * 0.1
+    adam.step()
+
+    # p1 should be 32-bit, p2 should be 8-bit
+    assert adam.state[p1]["state1"].dtype == torch.float32
+    assert adam.state[p2]["state1"].dtype == torch.uint8
+
+
 optimizer_names_8bit = [
     "adam8bit_blockwise",
     "lion8bit_blockwise",
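
For reviewers, a minimal sketch of the end-to-end flow this patch is meant to unblock (the toy nn.Linear model and CUDA device are illustrative assumptions, not part of the patch):

    import torch
    import bitsandbytes as bnb

    mng = bnb.optim.GlobalOptimManager.get_instance()

    model = torch.nn.Linear(64, 64)
    mng.register_parameters(model.parameters())  # register while still on CPU
    model = model.cuda()  # Module.cuda() keeps the same Parameter objects, so ids survive

    # Before this patch, the override below was silently dropped: register_parameters
    # had already run, and get_config only consulted index2config. The pid2config
    # fallback added here picks it up by parameter id at config-lookup time.
    mng.override_config(model.weight, "optim_bits", 32)

    optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)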