diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ea5451502..8fb61a7a6 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -290,6 +290,13 @@ def from_prequantized(
 
         return self
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
@@ -486,7 +493,7 @@ def forward(self, x: torch.Tensor):
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
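
For context on the added override: this is the standard PyTorch pattern for a `torch.Tensor` subclass that wants torch-level operations to return plain tensors rather than re-wrapped subclass instances (the default `Tensor.__torch_function__` converts results back to the calling subclass). Below is a minimal, self-contained sketch of that pattern; `PlainResultTensor` is a hypothetical class used only for illustration and is not part of bitsandbytes.

```python
import torch

class PlainResultTensor(torch.Tensor):
    """Hypothetical subclass mirroring the __torch_function__ override above."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Inside DisableTorchFunctionSubclass the arguments are treated as plain
        # tensors, and because the default re-wrapping step is skipped, the
        # result comes back as torch.Tensor rather than this subclass.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

t = torch.randn(2, 3).as_subclass(PlainResultTensor)
print(type(t).__name__)        # PlainResultTensor
print(type(t.t()).__name__)    # Tensor (not PlainResultTensor)
print(type(t + 1).__name__)    # Tensor
```

One way to read the accompanying `forward` change is that `self.weight.data` is already a detached plain-tensor view, so the transpose there no longer routes through the subclass's `__torch_function__` at all.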