4 changes: 0 additions & 4 deletions benchmarking/switchback/README.md

This file was deleted.

60 changes: 0 additions & 60 deletions benchmarking/switchback/info_a100_py2.jsonl

This file was deleted.

151 changes: 0 additions & 151 deletions benchmarking/switchback/make_plot_with_jsonl.py

This file was deleted.

Binary file removed benchmarking/switchback/plot_with_info.pdf
Binary file not shown.
160 changes: 0 additions & 160 deletions benchmarking/switchback/speed_benchmark.py

This file was deleted.

2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py

@@ -9,7 +9,7 @@

 import torch

-from . import _ops, research, utils
+from . import _ops, utils
 from .autograd._functions import (
     MatmulLtState,
     matmul,
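Note: after this change the research module is no longer imported when the package loads. A quick sketch of the observable effect (assumes nothing else imports the submodule explicitly):

import bitsandbytes

# The attribute only exists if something imports it; the eager import
# removed above no longer does so.
print(hasattr(bitsandbytes, "research"))  # False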
26 changes: 14 additions & 12 deletions bitsandbytes/autograd/_functions.py

@@ -53,16 +53,12 @@ def get_current_outlier_idx(self):

 @dataclass
 class MatmulLtState:
-    _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
-
     force_no_igemmlt: bool = False

     CB: Optional[torch.Tensor] = None
-    CxB: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
     SB: Optional[torch.Tensor] = None
     SCB: Optional[torch.Tensor] = None

-    CxBt: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
     SBt: Optional[torch.Tensor] = None
     CBt: Optional[torch.Tensor] = None

@@ -75,22 +71,29 @@ class MatmulLtState:
     is_training = True
     has_fp16_weights = True
     use_pool = False
-    formatB = "row"  # TODO: Deprecate/remove
+
+    # Deprecated attributes kept for downstream compatibility (TGI, vLLM).
+    # These are always None and will be fully removed in the next release.
+    _deprecated_fields = frozenset({"CxB", "CxBt", "formatB", "_tile_indices"})
+
+    def __getattr__(self, name):
+        if name in MatmulLtState._deprecated_fields:
+            warnings.warn(
+                f"MatmulLtState.{name} is deprecated and will be removed in the next bitsandbytes release.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            return None
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

     def reset_grads(self):
         self.CB = None
-        self.CxB = None
         self.SB = None
         self.SCB = None

-        self.CxBt = None
         self.SBt = None
         self.CBt = None

-    @property
-    def tile_indices(self):
-        raise ValueError("tile_indices is no longer supported.")
-

 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod

@@ -293,7 +296,6 @@ def backward(ctx, grad_output):

 class MatMul4Bit(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
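Note: with the __getattr__ hook above, the removed fields still resolve (to None) but now emit a FutureWarning, while unknown names raise AttributeError as before. A minimal sketch of the behavior a downstream caller such as TGI or vLLM would observe:

import warnings

from bitsandbytes import MatmulLtState

state = MatmulLtState()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert state.CxB is None  # deprecated name: warns, returns None
    assert state.CB is None   # real dataclass field: no warning

# Only the deprecated access produced a FutureWarning.
assert any(issubclass(w.category, FutureWarning) for w in caught)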
3 changes: 1 addition & 2 deletions bitsandbytes/backends/utils.py

@@ -4,9 +4,8 @@

 import torch

 try:
-    import triton.language as tl  # noqa: F401
-
     import triton  # noqa: F401
+    import triton.language as tl  # noqa: F401

     triton_available = True
 except ImportError:
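Note: this try/except probes the optional Triton dependency once at import time and exposes a module-level flag for callers to branch on. A self-contained sketch of the pattern, runnable with or without Triton installed (pick_backend is a hypothetical helper, not part of bitsandbytes):

try:
    import triton  # noqa: F401
    import triton.language as tl  # noqa: F401

    triton_available = True
except ImportError:
    triton_available = False


def pick_backend() -> str:
    # Choose a kernel backend based on what is installed.
    return "triton" if triton_available else "pytorch"


print(pick_backend())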