Skip to content

Commit da96dde

Browse files
Update Params4bit __torch_function__
1 parent e25e0ab commit da96dde

File tree

1 file changed

+73
-66
lines changed

1 file changed

+73
-66
lines changed

bitsandbytes/nn/modules.py

Lines changed: 73 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -223,18 +223,18 @@ def __new__(
223223
data = torch.empty(0)
224224

225225
# Handle FakeTensor creation during dynamo tracing
226-
if torch._dynamo.is_compiling() and not isinstance(data, cls):
227-
if isinstance(data, torch._subclasses.FakeTensor):
228-
param = data.as_subclass(cls)
229-
param.requires_grad = requires_grad
230-
param.quant_state = quant_state
231-
param.blocksize = blocksize
232-
param.compress_statistics = compress_statistics
233-
param.quant_type = quant_type
234-
param.quant_storage = quant_storage
235-
param.module = module
236-
param.bnb_quantized = bnb_quantized
237-
return param
226+
# if torch._dynamo.is_compiling() and not isinstance(data, cls):
227+
# if isinstance(data, torch._subclasses.FakeTensor):
228+
# param = data.as_subclass(cls)
229+
# param.requires_grad = requires_grad
230+
# param.quant_state = quant_state
231+
# param.blocksize = blocksize
232+
# param.compress_statistics = compress_statistics
233+
# param.quant_type = quant_type
234+
# param.quant_storage = quant_storage
235+
# param.module = module
236+
# param.bnb_quantized = bnb_quantized
237+
# return param
238238

239239
# Standard initialization for real tensors
240240
self = torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -356,63 +356,70 @@ def to(self, *args, **kwargs):
356356
bnb_quantized=self.bnb_quantized,
357357
)
358358

359-
def __tensor_flatten__(self):
360-
"""Return data tensor and non-tensor context"""
361-
ctx = {
362-
"quant_state": self.quant_state,
363-
"blocksize": self.blocksize,
364-
"compress_statistics": self.compress_statistics,
365-
"quant_type": self.quant_type,
366-
"quant_storage": self.quant_storage,
367-
"module": self.module,
368-
"bnb_quantized": self.bnb_quantized,
369-
}
370-
return ["data"], ctx
371-
372-
@staticmethod
373-
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
374-
"""Reconstruct Params4bit from components"""
375-
data = inner_tensors["data"]
376-
377-
# Special handling for FakeTensor reconstruction
378-
if isinstance(data, torch._subclasses.FakeTensor):
379-
param = data.as_subclass(Params4bit)
380-
param.blocksize = ctx["blocksize"]
381-
param.compress_statistics = ctx["compress_statistics"]
382-
param.quant_type = ctx["quant_type"]
383-
param.quant_state = ctx["quant_state"]
384-
param.quant_storage = ctx["quant_storage"]
385-
param.module = ctx["module"]
386-
param.bnb_quantized = ctx["bnb_quantized"]
387-
return param
388-
389-
# Standard reconstruction for real tensors
390-
return Params4bit(
391-
data,
392-
requires_grad=data.requires_grad,
393-
quant_state=ctx["quant_state"],
394-
blocksize=ctx["blocksize"],
395-
compress_statistics=ctx["compress_statistics"],
396-
quant_type=ctx["quant_type"],
397-
quant_storage=ctx["quant_storage"],
398-
module=ctx["module"],
399-
bnb_quantized=ctx["bnb_quantized"],
400-
)
359+
# def __tensor_flatten__(self):
360+
# """Return data tensor and non-tensor context"""
361+
# ctx = {
362+
# "quant_state": self.quant_state,
363+
# "blocksize": self.blocksize,
364+
# "compress_statistics": self.compress_statistics,
365+
# "quant_type": self.quant_type,
366+
# "quant_storage": self.quant_storage,
367+
# "module": self.module,
368+
# "bnb_quantized": self.bnb_quantized,
369+
# }
370+
# return ["data"], ctx
371+
372+
# @staticmethod
373+
# def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
374+
# """Reconstruct Params4bit from components"""
375+
# data = inner_tensors["data"]
376+
377+
# # Special handling for FakeTensor reconstruction
378+
# if isinstance(data, torch._subclasses.FakeTensor):
379+
# param = data.as_subclass(Params4bit)
380+
# param.blocksize = ctx["blocksize"]
381+
# param.compress_statistics = ctx["compress_statistics"]
382+
# param.quant_type = ctx["quant_type"]
383+
# param.quant_state = ctx["quant_state"]
384+
# param.quant_storage = ctx["quant_storage"]
385+
# param.module = ctx["module"]
386+
# param.bnb_quantized = ctx["bnb_quantized"]
387+
# return param
388+
389+
# # Standard reconstruction for real tensors
390+
# return Params4bit(
391+
# data,
392+
# requires_grad=data.requires_grad,
393+
# quant_state=ctx["quant_state"],
394+
# blocksize=ctx["blocksize"],
395+
# compress_statistics=ctx["compress_statistics"],
396+
# quant_type=ctx["quant_type"],
397+
# quant_storage=ctx["quant_storage"],
398+
# module=ctx["module"],
399+
# bnb_quantized=ctx["bnb_quantized"],
400+
# )
401401

402402
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch ``func`` with subclass interposition turned off.

    While ``torch._C.DisableTorchFunctionSubclass`` is active, the call
    behaves as if no tensor subclass were involved, so the op runs on the
    underlying tensor data and the result is NOT re-wrapped as ``cls``.

    Args:
        func: the torch operation being intercepted.
        types: subclass types participating in the call (unused here).
        args: positional arguments forwarded to ``func``.
        kwargs: keyword arguments forwarded to ``func``; ``None`` means none.
    """
    call_kwargs = {} if kwargs is None else kwargs
    # Standard idiom for "plain tensor behavior" inside an override:
    # suppress this __torch_function__ so func does not recurse into it.
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **call_kwargs)
408+
409+
# @classmethod
410+
# def __torch_function__(cls, func, types, args=(), kwargs=None):
411+
# # Type preservation through ops
412+
# result = super().__torch_function__(func, types, args, kwargs or {})
413+
# if isinstance(result, torch.Tensor) and not isinstance(result, cls):
414+
# return result.as_subclass(cls)
415+
# return result
416+
417+
# @classmethod
418+
# def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
419+
# # Delegate to FakeTensor implementation when needed
420+
# if any(isinstance(x, torch._subclasses.FakeTensor) for x in args):
421+
# return torch._C.DispatchKey.Fake(func(*args, **(kwargs or {})))
422+
# return super().__torch_dispatch__(func, types, args, kwargs)
416423

417424
def detach(self):
418425
"""Create new instance preserving quantization state"""

0 commit comments

Comments
 (0)