@@ -223,18 +223,18 @@ def __new__(
223223 data = torch .empty (0 )
224224
225225 # Handle FakeTensor creation during dynamo tracing
226- if torch ._dynamo .is_compiling () and not isinstance (data , cls ):
227- if isinstance (data , torch ._subclasses .FakeTensor ):
228- param = data .as_subclass (cls )
229- param .requires_grad = requires_grad
230- param .quant_state = quant_state
231- param .blocksize = blocksize
232- param .compress_statistics = compress_statistics
233- param .quant_type = quant_type
234- param .quant_storage = quant_storage
235- param .module = module
236- param .bnb_quantized = bnb_quantized
237- return param
226+ # if torch._dynamo.is_compiling() and not isinstance(data, cls):
227+ # if isinstance(data, torch._subclasses.FakeTensor):
228+ # param = data.as_subclass(cls)
229+ # param.requires_grad = requires_grad
230+ # param.quant_state = quant_state
231+ # param.blocksize = blocksize
232+ # param.compress_statistics = compress_statistics
233+ # param.quant_type = quant_type
234+ # param.quant_storage = quant_storage
235+ # param.module = module
236+ # param.bnb_quantized = bnb_quantized
237+ # return param
238238
239239 # Standard initialization for real tensors
240240 self = torch .Tensor ._make_subclass (cls , data , requires_grad )
@@ -356,63 +356,70 @@ def to(self, *args, **kwargs):
356356 bnb_quantized = self .bnb_quantized ,
357357 )
358358
359- def __tensor_flatten__ (self ):
360- """Return data tensor and non-tensor context"""
361- ctx = {
362- "quant_state" : self .quant_state ,
363- "blocksize" : self .blocksize ,
364- "compress_statistics" : self .compress_statistics ,
365- "quant_type" : self .quant_type ,
366- "quant_storage" : self .quant_storage ,
367- "module" : self .module ,
368- "bnb_quantized" : self .bnb_quantized ,
369- }
370- return ["data" ], ctx
371-
372- @staticmethod
373- def __tensor_unflatten__ (inner_tensors , ctx , outer_size , outer_stride ):
374- """Reconstruct Params4bit from components"""
375- data = inner_tensors ["data" ]
376-
377- # Special handling for FakeTensor reconstruction
378- if isinstance (data , torch ._subclasses .FakeTensor ):
379- param = data .as_subclass (Params4bit )
380- param .blocksize = ctx ["blocksize" ]
381- param .compress_statistics = ctx ["compress_statistics" ]
382- param .quant_type = ctx ["quant_type" ]
383- param .quant_state = ctx ["quant_state" ]
384- param .quant_storage = ctx ["quant_storage" ]
385- param .module = ctx ["module" ]
386- param .bnb_quantized = ctx ["bnb_quantized" ]
387- return param
388-
389- # Standard reconstruction for real tensors
390- return Params4bit (
391- data ,
392- requires_grad = data .requires_grad ,
393- quant_state = ctx ["quant_state" ],
394- blocksize = ctx ["blocksize" ],
395- compress_statistics = ctx ["compress_statistics" ],
396- quant_type = ctx ["quant_type" ],
397- quant_storage = ctx ["quant_storage" ],
398- module = ctx ["module" ],
399- bnb_quantized = ctx ["bnb_quantized" ],
400- )
359+ # def __tensor_flatten__(self):
360+ # """Return data tensor and non-tensor context"""
361+ # ctx = {
362+ # "quant_state": self.quant_state,
363+ # "blocksize": self.blocksize,
364+ # "compress_statistics": self.compress_statistics,
365+ # "quant_type": self.quant_type,
366+ # "quant_storage": self.quant_storage,
367+ # "module": self.module,
368+ # "bnb_quantized": self.bnb_quantized,
369+ # }
370+ # return ["data"], ctx
371+
372+ # @staticmethod
373+ # def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
374+ # """Reconstruct Params4bit from components"""
375+ # data = inner_tensors["data"]
376+
377+ # # Special handling for FakeTensor reconstruction
378+ # if isinstance(data, torch._subclasses.FakeTensor):
379+ # param = data.as_subclass(Params4bit)
380+ # param.blocksize = ctx["blocksize"]
381+ # param.compress_statistics = ctx["compress_statistics"]
382+ # param.quant_type = ctx["quant_type"]
383+ # param.quant_state = ctx["quant_state"]
384+ # param.quant_storage = ctx["quant_storage"]
385+ # param.module = ctx["module"]
386+ # param.bnb_quantized = ctx["bnb_quantized"]
387+ # return param
388+
389+ # # Standard reconstruction for real tensors
390+ # return Params4bit(
391+ # data,
392+ # requires_grad=data.requires_grad,
393+ # quant_state=ctx["quant_state"],
394+ # blocksize=ctx["blocksize"],
395+ # compress_statistics=ctx["compress_statistics"],
396+ # quant_type=ctx["quant_type"],
397+ # quant_storage=ctx["quant_storage"],
398+ # module=ctx["module"],
399+ # bnb_quantized=ctx["bnb_quantized"],
400+ # )
401401
402402 @classmethod
403403 def __torch_function__ (cls , func , types , args = (), kwargs = None ):
404- # Type preservation through ops
405- result = super ().__torch_function__ (func , types , args , kwargs or {})
406- if isinstance (result , torch .Tensor ) and not isinstance (result , cls ):
407- return result .as_subclass (cls )
408- return result
409-
410- @classmethod
411- def __torch_dispatch__ (cls , func , types , args = (), kwargs = None ):
412- # Delegate to FakeTensor implementation when needed
413- if any (isinstance (x , torch ._subclasses .FakeTensor ) for x in args ):
414- return torch ._C .DispatchKey .Fake (func (* args , ** (kwargs or {})))
415- return super ().__torch_dispatch__ (func , types , args , kwargs )
404+ if kwargs is None :
405+ kwargs = {}
406+ with torch ._C .DisableTorchFunctionSubclass ():
407+ return func (* args , ** kwargs )
408+
409+ # @classmethod
410+ # def __torch_function__(cls, func, types, args=(), kwargs=None):
411+ # # Type preservation through ops
412+ # result = super().__torch_function__(func, types, args, kwargs or {})
413+ # if isinstance(result, torch.Tensor) and not isinstance(result, cls):
414+ # return result.as_subclass(cls)
415+ # return result
416+
417+ # @classmethod
418+ # def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
419+ # # Delegate to FakeTensor implementation when needed
420+ # if any(isinstance(x, torch._subclasses.FakeTensor) for x in args):
421+ # return torch._C.DispatchKey.Fake(func(*args, **(kwargs or {})))
422+ # return super().__torch_dispatch__(func, types, args, kwargs)
416423
417424 def detach (self ):
418425 """Create new instance preserving quantization state"""