diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index fe687e1e8..d42facea8 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -487,6 +487,18 @@ def __init__(
         self.state2 = state2
         self.nested = state2 is not None
 
+    def __getattr__(self, name):
+        # Support attribute access for packed state_dict keys like "bitsandbytes__nf4".
+        # PyTorch's FSDP state_dict traversal (_get_fqns) resolves dotted FQN paths via
+        # getattr. The packed key "quant_state.bitsandbytes__nf4" causes it to call
+        # getattr(quant_state_obj, "bitsandbytes__nf4"), which we handle here.
+        if name.startswith("bitsandbytes__"):
+            qs_dict = self.as_dict(packed=True)
+            packed_key = "quant_state." + name
+            if packed_key in qs_dict:
+                return qs_dict[packed_key]
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
     def __getitem__(self, idx):
         """
         ensures compatibility with older quant state scheme with nested lists.
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 67847f40c..b0e851010 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -256,6 +256,43 @@ def __setstate__(self, state):
         self.bnb_quantized = state["bnb_quantized"]
         self.module = state["module"]
 
+    # Map from state_dict key names (as produced by QuantState.as_dict) to
+    # the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
+    # dotted FQN keys via getattr, so "weight.quant_map" becomes
+    # getattr(weight, "quant_map") — we must map that to quant_state.code.
+    _QUANT_STATE_ATTR_MAP = {
+        # Direct QuantState attributes
+        "absmax": lambda qs: qs.absmax,
+        "code": lambda qs: qs.code,
+        "blocksize": lambda qs: qs.blocksize,
+        "dtype": lambda qs: qs.dtype,
+        "shape": lambda qs: qs.shape,
+        "offset": lambda qs: qs.offset,
+        "state2": lambda qs: qs.state2,
+        # as_dict serializes code → "quant_map"
+        "quant_map": lambda qs: qs.code,
+        "quant_type": lambda qs: qs.quant_type,
+        # as_dict serializes nested state2 attributes under "nested_*" keys
+        "nested_absmax": lambda qs: qs.state2.absmax,
+        "nested_blocksize": lambda qs: qs.state2.blocksize,
+        "nested_quant_map": lambda qs: qs.state2.code,
+        "nested_dtype": lambda qs: qs.state2.dtype,
+        "nested_offset": lambda qs: qs.offset,
+    }
+
+    def __getattr__(self, name):
+        # Proxy known QuantState attributes so that PyTorch's FSDP state_dict
+        # machinery (which traverses FQN paths via getattr) can find them.
+        accessor = self._QUANT_STATE_ATTR_MAP.get(name)
+        if accessor is not None:
+            quant_state = self.__dict__.get("quant_state")
+            if quant_state is not None:
+                try:
+                    return accessor(quant_state)
+                except AttributeError:
+                    pass
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
     def __deepcopy__(self, memo):
         new_instance = type(self).__new__(type(self))
         state = self.__getstate__()
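
For context on the two __getattr__ hooks above: torch.distributed.checkpoint resolves each state_dict key ("weight.absmax", "weight.quant_state.bitsandbytes__nf4", ...) by walking the dotted path one getattr at a time. The sketch below reproduces that walk outside of FSDP; resolve_fqn is a hypothetical stand-in for the internal _get_fqns() traversal, not a PyTorch API, and it assumes a CUDA device so the Linear4bit weight is actually quantized.

    # Hypothetical helper mirroring the getattr walk; not part of PyTorch or this patch.
    import torch
    import bitsandbytes as bnb

    def resolve_fqn(root, fqn):
        obj = root
        for part in fqn.split("."):
            obj = getattr(obj, part)  # any missing hop raises AttributeError
        return obj

    layer = bnb.nn.Linear4bit(64, 64, bias=False, quant_type="nf4").cuda()  # .cuda() quantizes the weight
    # Before this patch, both of the following lookups raised AttributeError.
    absmax = resolve_fqn(layer, "weight.absmax")
    packed = resolve_fqn(layer, "weight.quant_state.bitsandbytes__nf4")
    print(absmax.shape, packed.dtype)
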
diff --git a/tests/fsdp_state_dict_save.py b/tests/fsdp_state_dict_save.py
new file mode 100644
index 000000000..2e56c1c03
--- /dev/null
+++ b/tests/fsdp_state_dict_save.py
@@ -0,0 +1,80 @@
+"""FSDP state_dict save integration test for 4-bit quantized models (#1405).
+
+This script must be launched via torchrun (not directly):
+    torchrun --nproc_per_node=1 tests/fsdp_state_dict_save.py
+
+It wraps a QLoRA-style model (frozen 4-bit base + trainable adapter) in FSDP
+and calls get_model_state_dict with cpu_offload=True, which exercises the
+_get_fqns() getattr traversal that previously crashed with:
+    AttributeError: 'Params4bit' object has no attribute 'absmax'
+"""
+
+import sys
+
+import torch
+import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+import torch.nn as nn
+
+import bitsandbytes as bnb
+
+
+class SimpleQLoRAModel(nn.Module):
+    """Minimal model with a frozen 4-bit base layer and a trainable adapter."""
+
+    def __init__(self, quant_type="nf4"):
+        super().__init__()
+        self.base = bnb.nn.Linear4bit(64, 64, bias=False, quant_type=quant_type)
+        self.adapter = nn.Linear(64, 64, bias=False)
+
+    def forward(self, x):
+        return self.base(x) + self.adapter(x)
+
+
+def main():
+    dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+
+    errors = []
+
+    for quant_type in ("nf4", "fp4"):
+        model = SimpleQLoRAModel(quant_type=quant_type)
+        model = model.to("cuda")
+
+        # Freeze quantized base weights (as in real QLoRA)
+        for p in model.base.parameters():
+            p.requires_grad = False
+
+        # Tell FSDP to ignore the frozen quantized params (can't flatten int dtypes)
+        ignored = list(model.base.parameters())
+        fsdp_model = FSDP(model, device_id=rank, ignored_states=ignored, use_orig_params=True)
+
+        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+        try:
+            state_dict = get_model_state_dict(fsdp_model, options=options)
+
+            # Verify expected keys are present
+            expected_substrings = ["base.weight", "absmax", "quant_map", "adapter.weight"]
+            for substr in expected_substrings:
+                if not any(substr in k for k in state_dict.keys()):
+                    errors.append(f"{quant_type}: missing key containing '{substr}' in {list(state_dict.keys())}")
+
+            print(f"{quant_type}: SUCCESS ({len(state_dict)} keys)", flush=True)
+        except Exception as e:
+            errors.append(f"{quant_type}: {type(e).__name__}: {e}")
+            print(f"{quant_type}: FAILED: {e}", flush=True)
+
+    dist.destroy_process_group()
+
+    if errors:
+        print("\nFAILURES:\n" + "\n".join(errors), file=sys.stderr, flush=True)
+        sys.exit(1)
+    else:
+        print("\nAll FSDP state_dict tests passed.", flush=True)
+        sys.exit(0)
+
+
+if __name__ == "__main__":
+    main()
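
The expected_substrings check above depends on Linear4bit writing its quantization state into the state_dict next to the packed weight. Below is a quick single-process sanity check of those key names (no FSDP, no torchrun), assuming a CUDA device; the Sequential wrapper and the commented key list are illustrative, and the exact keys vary with quant_type and compress_statistics.

    import torch
    import bitsandbytes as bnb

    base = torch.nn.Sequential(bnb.nn.Linear4bit(64, 64, bias=False, quant_type="nf4")).cuda()
    for key in base.state_dict():
        print(key)
    # Typically includes, among others:
    #   0.weight
    #   0.weight.absmax
    #   0.weight.quant_map
    #   0.weight.quant_state.bitsandbytes__nf4
    # These dotted suffixes are the FQN segments that _get_fqns() resolves via getattr.
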
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index de40d158c..1585ea389 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,7 +1,9 @@
 import copy
 import os
+import pathlib
 import pickle
 import platform
+import subprocess
 import sys
 from tempfile import TemporaryDirectory
 
@@ -431,3 +433,96 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     grad_compiled = x.grad.clone()
 
     torch.testing.assert_close(grad_compiled, grad_ref)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_params4bit_quant_state_attr_access(device, quant_type, compress_statistics):
+    """Test that Params4bit proxies QuantState attributes for FSDP state_dict traversal (#1405).
+
+    PyTorch's FSDP state_dict machinery traverses FQN paths like
+    'model.layers.0.weight.absmax' using getattr(). This test verifies
+    that Params4bit and QuantState expose the attributes that appear as
+    state_dict keys so that _get_fqns() traversal succeeds.
+    """
+    if device == "hpu" and not is_supported_on_hpu(quant_type):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    layer = bnb.nn.Linear4bit(
+        64,
+        64,
+        bias=False,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+    )
+    layer = layer.to(device)
+    w = layer.weight
+
+    assert w.quant_state is not None, "quant_state should be set after quantization"
+
+    # Direct QuantState attributes proxied through Params4bit
+    assert torch.equal(w.absmax, w.quant_state.absmax)
+    assert torch.equal(w.code, w.quant_state.code)
+
+    # "quant_map" is how as_dict() serializes "code" — FSDP uses this key name
+    assert torch.equal(w.quant_map, w.quant_state.code)
+
+    # QuantState packed key: as_dict(packed=True) produces "quant_state.bitsandbytes__<quant_type>";
+    # FSDP resolves this as getattr(quant_state_obj, "bitsandbytes__<quant_type>")
+    packed_attr = f"bitsandbytes__{quant_type}"
+    assert hasattr(w.quant_state, packed_attr)
+    packed_val = getattr(w.quant_state, packed_attr)
+    assert isinstance(packed_val, torch.Tensor)
+
+    # Simulate the full FSDP _get_fqns traversal for all state_dict keys
+    state_dict_keys = list(w.quant_state.as_dict(packed=True).keys())
+    for key in state_dict_keys:
+        # Each key is relative to "weight.", e.g. "absmax" or "quant_state.bitsandbytes__nf4"
+        parts = key.split(".")
+        obj = w
+        for part in parts:
+            obj = getattr(obj, part)
+        assert obj is not None
+
+    # hasattr should return True for proxied attrs, False for unknown ones
+    assert hasattr(w, "absmax")
+    assert hasattr(w, "code")
+    assert hasattr(w, "quant_map")
+    assert not hasattr(w, "nonexistent_attribute")
+
+    # Unknown attributes must still raise AttributeError
+    with pytest.raises(AttributeError, match="nonexistent_attribute"):
+        _ = w.nonexistent_attribute
+
+    # Verify that normal Params4bit attributes are unaffected by __getattr__
+    assert isinstance(w.quant_state, bnb.functional.QuantState)
+    assert isinstance(w.bnb_quantized, bool)
+    assert w.bnb_quantized is True
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
+@pytest.mark.skipif(
+    not torch.distributed.is_nccl_available(),
+    reason="FSDP test requires NCCL backend",
+)
+def test_fsdp_state_dict_save_4bit():
+    """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).
+
+    Launches a single-GPU FSDP process via torchrun to exercise the real
+    _get_fqns() code path that previously crashed with:
+        AttributeError: 'Params4bit' object has no attribute 'absmax'
+    """
+    script = pathlib.Path(__file__).with_name("fsdp_state_dict_save.py")
+    result = subprocess.run(
+        ["torchrun", "--nproc_per_node=1", str(script)],
+        capture_output=True,
+        text=True,
+        timeout=120,
+    )
+    if result.returncode != 0:
+        pytest.fail(
+            f"FSDP state_dict test failed (exit {result.returncode}):\n"
+            f"stdout: {result.stdout}\n"
+            f"stderr: {result.stderr}"
+        )
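
A device-free spot check of the QuantState.__getattr__ change is also possible, since as_dict(packed=True) only needs the fields passed to __init__. The tensors below are random placeholders rather than a real quantization state, so this exercises only the attribute-resolution path.

    import torch
    from bitsandbytes.functional import QuantState

    qs = QuantState(
        absmax=torch.rand(64),
        shape=(64, 64),
        code=torch.rand(16),
        blocksize=64,
        quant_type="nf4",
        dtype=torch.float16,
    )
    packed = qs.bitsandbytes__nf4  # resolved through the new __getattr__
    assert isinstance(packed, torch.Tensor)
    assert not hasattr(qs, "bitsandbytes__int8")  # unknown packed keys still fail cleanly
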