7 changes: 7 additions & 0 deletions bitsandbytes/optim/optimizer.py
@@ -352,6 +352,13 @@ def get_config(self, gindex, pindex, group):

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])

        # Also check pid2config as a fallback so that override_config works
        # regardless of whether it was called before or after register_parameters.
        p = self.param_groups[gindex]["params"][pindex]
        if id(p) in self.mng.pid2config:
            config.update(self.mng.pid2config[id(p)])

        return config

    def init_state(self, group, p, gindex, pindex):
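In effect, `get_config` now merges three layers: the optimizer defaults, any per-index settings copied into `index2config` by `register_parameters`, and any per-tensor overrides stored in `pid2config`. A minimal standalone sketch of that merge order (an illustration following the hunk above, not the library's actual code):

```python
# Sketch of the lookup order get_config now follows. The pid2config lookup
# runs last, so an override_config call wins regardless of whether it was
# made before or after register_parameters.
def merged_config(defaults, index2config, pid2config, gindex, pindex, p):
    config = dict(defaults)
    if (gindex, pindex) in index2config:
        config.update(index2config[(gindex, pindex)])  # set via register_parameters
    if id(p) in pid2config:
        config.update(pid2config[id(p)])  # set via override_config, keyed by object id
    return config
```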
33 changes: 33 additions & 0 deletions tests/test_optim.py
@@ -302,6 +302,39 @@ def test_global_config(dim1, dim2, gtype, device):
    assert adam2.state[p3]["state2"].dtype == torch.uint8


@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
def test_override_config_after_register(device):
    """Test that override_config works when called after register_parameters (issue #1269)."""
    if device not in ["cuda", "xpu"]:
        pytest.skip("Optimizers are only supported on CUDA and XPU")

    mng = bnb.optim.GlobalOptimManager.get_instance()
    mng.initialize()

    p1 = torch.randn(64, 64, device="cpu") * 0.1
    p2 = torch.randn(64, 64, device="cpu") * 0.1

    # Register first, override second (the documented order)
    mng.register_parameters([p1, p2])
    p1 = p1.to(device)
    p2 = p2.to(device)

    # Override p2 to use 8-bit after register_parameters
    mng.override_config(p2, "optim_bits", 8)

    adam = bnb.optim.Adam([p1, p2], lr=0.001, optim_bits=32)

    # Run a step to trigger init_state
    p1.grad = torch.randn_like(p1) * 0.1
    p2.grad = torch.randn_like(p2) * 0.1
    adam.step()

    # p1 should be 32-bit, p2 should be 8-bit
    assert adam.state[p1]["state1"].dtype == torch.float32
    assert adam.state[p2]["state1"].dtype == torch.uint8


optimizer_names_8bit = [
    "adam8bit_blockwise",
    "lion8bit_blockwise",
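One detail worth noting about the test above: `pid2config` is keyed by `id(p)`, and `.to(device)` returns a new tensor object, so the override has to target the exact tensor object that is later handed to the optimizer, which is the tensor after the `.to(device)` move. A small device-agnostic illustration (plain PyTorch, nothing bitsandbytes-specific assumed):

```python
import torch

p_cpu = torch.randn(4, 4)
# .to(another device) and .clone() return new tensor objects with new ids,
# so an id()-keyed override applied to p_cpu would not match the moved tensor.
p_moved = p_cpu.to("cuda") if torch.cuda.is_available() else p_cpu.clone()
assert id(p_cpu) != id(p_moved)
```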