diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ba134f52a..e599643cc 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -356,6 +356,46 @@ def to(self, *args, **kwargs): return new_param + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in [torch.chunk, torch.split]: + tensor = args[0] + + result = super().__torch_function__(func, types, args, kwargs) + + if isinstance(result, tuple): + return tuple( + cls( + data=chunk, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + for chunk in result + ) + else: + return cls( + data=result, + requires_grad=tensor.requires_grad, + quant_state=tensor.quant_state, + blocksize=tensor.blocksize, + compress_statistics=tensor.compress_statistics, + quant_type=tensor.quant_type, + quant_storage=tensor.quant_storage, + module=tensor.module, + bnb_quantized=tensor.bnb_quantized, + ) + + return super().__torch_function__(func, types, args, kwargs) + def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]): if getattr(module.weight, "quant_state", None) is not None: diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index e07b54d2d..1c5e77a32 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +def test_params4bit_torch_chunk_split(device, quant_type): + """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_params4bit_torch_chunk_split(device, quant_type):
    """torch.chunk and torch.split must preserve the Params4bit subclass
    and its quantization metadata (required for FSDP2 compatibility)."""
    if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
        pytest.skip("This configuration is not supported on HPU.")

    if device == "cpu":
        pytest.skip("CPU quantization causes segfault, skipping CPU test")

    original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
    params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)

    # CPU was skipped above, so the move is unconditional (the original
    # guarded it with an always-true `if device != "cpu"` check).
    params4bit = params4bit.to(device)

    # Same invariants for both splitting ops, so check them in one loop
    # instead of duplicating the assertion block.
    for op in (torch.chunk, torch.split):
        pieces = op(params4bit, 2, dim=0)

        assert isinstance(pieces, tuple), f"torch.{op.__name__} should return tuple"
        assert len(pieces) > 0, "Should have at least one piece"
        for piece in pieces:
            assert isinstance(piece, bnb.nn.Params4bit), f"{op.__name__} result should preserve Params4bit subclass"
            assert piece.quant_type == params4bit.quant_type, "Should preserve quant_type value"