diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ccd842ce3..e349cc843 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -354,6 +354,7 @@ def to(self, *args, **kwargs): compress_statistics=self.compress_statistics, quant_type=self.quant_type, quant_storage=self.quant_storage, + bnb_quantized=self.bnb_quantized, ) return new_param