From b89c8996183a01a0058d59269eac897d8c05a5dd Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 4 Nov 2025 13:28:31 -0500
Subject: [PATCH] Remove deprecated code

Removes Python helpers that were previously marked deprecated with a
FutureWarning, along with the native kernels, layout helpers, and bindings
that supported them:

* bitsandbytes.autograd: get_inverse_transform_indices, undo_layout
* bitsandbytes.functional: get_colrow_absmax, get_row_absmax
* CUDA/HIP: kgetRowStats/getRowStats and the cget_row_stats binding,
  get_order, get_leading_dim, the Transform_t enum, and the leftover
  kTransformRowToFormat declaration
* tests: test_colrow_absmax
---
bitsandbytes/autograd/__init__.py | 1 -
bitsandbytes/autograd/_functions.py | 59 ------------------
bitsandbytes/functional.py | 96 -----------------------------
csrc/kernels.cu | 45 --------------
csrc/kernels.cuh | 7 ---
csrc/kernels.hip | 43 -------------
csrc/kernels_hip.cuh | 7 ---
csrc/ops.cu | 67 --------------------
csrc/ops.cuh | 9 ---
csrc/ops.hip | 81 ------------------------
csrc/ops_hip.cuh | 9 ---
csrc/pythonInterface.cpp | 4 --
tests/test_functional.py | 39 ------------
13 files changed, 467 deletions(-)
diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py
index f262d89ed..e69de29bb 100644
--- a/bitsandbytes/autograd/__init__.py
+++ b/bitsandbytes/autograd/__init__.py
@@ -1 +0,0 @@
-from ._functions import get_inverse_transform_indices, undo_layout
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 5391c8522..67420af3c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,3 @@
-from collections.abc import Callable
from dataclasses import dataclass
from math import prod
from typing import Optional
@@ -6,7 +5,6 @@
from warnings import warn
import torch
-from typing_extensions import deprecated
import bitsandbytes.functional as F
@@ -50,66 +48,9 @@ def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
-@deprecated(
- "This function is deprecated and will be removed in a future release.",
- category=FutureWarning,
-)
-def get_inverse_transform_indices(
- transform_tile: Callable[[torch.Tensor], torch.Tensor],
- tile_size: tuple[int, int],
-):
- """
- Compute a permutation of indices that invert the specified (tiled) matrix transformation
-
- :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
- :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
- :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
- :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
- :returns: indices
- """
- d1, d2 = tile_size
- assert 0 < d1 * d2 < 2**64
- tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
- # encode each position in tile as a tuple of <= 8 unique bytes
- permuted_tile_indices = torch.zeros_like(tile_indices)
- for i in range(8):
- # select i-th byte, apply transformation and trace where each index ended up
- ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
- sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
- assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
- permuted_tile_i = transform_tile(sample_tile_i)
- ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
- permuted_tile_indices += ith_permuted_indices * (256**i)
- if d1 * d2 < 256**i:
- break # if all indices fit in i bytes, stop early
- return permuted_tile_indices
-
-
_is_compiling = torch.compiler.is_compiling
-@deprecated(
- "This function is deprecated and will be removed in a future release.",
- category=FutureWarning,
-)
-def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
- """
- Undo a tiled permutation such as turing or ampere layout
-
- :param permuted_tensor: torch tensor in a permuted layout
- :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
- :return: contiguous row-major tensor
- """
- (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
- assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
- tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
- outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda
- outputs[tile_indices.flatten()] = tensor
- outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
- outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
- return outputs.reshape(rows, cols).contiguous()
-
-
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None # TODO: remove
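
The two helpers removed above recovered row-major data from opaque tiled layouts by pushing an index tensor through the same transform. Below is a minimal pure-PyTorch sketch of that trick (not bitsandbytes API; `tile_transform` is a toy same-shape layout standing in for the Turing/Ampere transforms the real callers passed in):

    import torch

    def tile_transform(tile: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for a vendor tile layout: store the tile column-major
        # while keeping the same (rows, cols) shape.
        return tile.t().contiguous().view(tile.shape)

    d1, d2 = 8, 32                      # (8, 32) was the Turing tile, (32, 32) Ampere
    idx = torch.arange(d1 * d2).view(d1, d2)

    # get_inverse_transform_indices pushed indices through the transform one byte
    # at a time (the real transforms only accept int8); with a pure-torch toy
    # transform the int64 indices can be pushed through directly.
    perm = tile_transform(idx)          # perm[p] = source index of the element now at p

    x = torch.randint(-128, 127, (d1, d2), dtype=torch.int8)
    y = tile_transform(x)

    # undo_layout, specialized to a single tile: scatter every permuted element
    # back to the row-major position recorded in perm.
    undone = torch.empty_like(y).view(-1)
    undone[perm.view(-1)] = y.view(-1)
    assert torch.equal(undone.view(d1, d2), x)
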
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 3d11276ad..abf160c7e 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
return result
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_colrow_absmax(
- A: torch.Tensor,
- row_stats: Optional[torch.Tensor] = None,
- col_stats: Optional[torch.Tensor] = None,
- nnz_block_ptr: Optional[torch.Tensor] = None,
- threshold=0.0,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """ "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
- The row-wise and column-wise absmax values are determined.
-
- For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-
- This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
- The column-wise quantization scales are not typically needed in inference scenarios.
-
-
- Args:
- A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
- row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
- col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
- nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
- threshold (`float`, *optional*):
- An optional threshold for sparse decomposition of outlier features.
- No outliers are held back when 0.0. Defaults to 0.0.
-
- Returns:
- `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
- - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
- - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
- - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
- """
- assert A.is_floating_point()
-
- outlier_mask = None
-
- if row_stats is None or col_stats is None:
- absA = A.abs().view(-1, A.shape[-1])
-
- if threshold > 0.0:
- # Filter outliers from stats when enabled
- outlier_mask = absA >= threshold
- absA.masked_fill_(outlier_mask, 0.0)
-
- if row_stats is None:
- # shape [rows]; unsqueeze(-1) gives [rows,1]
- # We have a CUDA kernel for row max, but not yet for cols.
- row_stats = get_row_absmax(A, threshold)
-
- if col_stats is None:
- # shape [cols]; unsqueeze(0) gives [1,cols]
- col_stats = absA.amax(dim=0, keepdim=False).float()
-
- return row_stats, col_stats, outlier_mask
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_row_absmax(A: torch.Tensor, threshold=0.0):
- """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
- For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
- Args:
- A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
- threshold (`float`, *optional*):
- An optional threshold for sparse decomposition of outlier features.
- No outliers are held back when 0.0. Defaults to 0.0.
-
- Returns:
- `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
- """
-
- assert A.dtype == torch.float16
-
- rows = prod(A.shape[:-1])
- cols = A.shape[-1]
-
- row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)
-
- is_on_gpu([A])
-
- with _cuda_device_of(A):
- lib.cget_row_stats(
- get_ptr(A),
- get_ptr(row_stats),
- ct.c_float(threshold),
- ct.c_int32(rows),
- ct.c_int32(cols),
- _get_tensor_stream(A),
- )
-
- return row_stats
-
-
class COOSparseTensor:
def __init__(
self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
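
For anyone still depending on the deleted statistics helpers, the same quantities are cheap to compute in plain PyTorch. A sketch of what get_colrow_absmax/get_row_absmax returned (the function name below is illustrative, not a bitsandbytes API):

    from typing import Optional, Tuple

    import torch

    def colrow_absmax_reference(
        A: torch.Tensor, threshold: float = 0.0
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # Row-wise and column-wise absmax of A, excluding |a| >= threshold
        # (the LLM.int8() outliers) from the statistics when threshold > 0.
        absA = A.abs().view(-1, A.shape[-1]).float()
        outlier_mask = None
        if threshold > 0.0:
            outlier_mask = absA >= threshold
            absA = absA.masked_fill(outlier_mask, 0.0)
        row_stats = absA.amax(dim=1)   # float32, shape [rows]
        col_stats = absA.amax(dim=0)   # float32, shape [cols]
        return row_stats, col_stats, outlier_mask

    A = torch.randn(128, 256, dtype=torch.float16)
    row_stats, col_stats, mask = colrow_absmax_reference(A, threshold=6.0)
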
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index de48f5e82..ff122d376 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1825,51 +1825,6 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
}
}
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
-    void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
-    using BlockReduceT = cub::BlockReduce<float, THREADS>;
-
- // One block per row.
- // Threads load column values in a striped arrangement.
- // e.g. t0 reads row[0], row[0+nthreads], ..
- // and t1 reads row[1], row[1+nthreads], ..
- // Each thread will determine its local absmax.
- // We then do a blockwise reduction to determine the row's absmax.
-
- __shared__ typename BlockReduceT::TempStorage temp_storage;
-
- const int row_id = blockIdx.x;
- const T* __restrict__ row_data = A + (row_id * cols);
-
- // Threads will read the row values in a striped access pattern and find a local absmax.
- float row_local_absmax = -FLT_MIN;
- for (int i = threadIdx.x; i < cols; i += THREADS) {
- const float absval = fabsf(row_data[i]);
-
- // For sparse decomposition, values outside of the threshold are not to be
- // included when calculating the row's absmax.
- if constexpr (SPARSE_DECOMP) {
- row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
- } else {
- row_local_absmax = fmaxf(row_local_absmax, absval);
- }
- }
-
- // Reduce thread-local absmax across the block.
- // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
-    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
- if (threadIdx.x == 0) {
- // Save our block's absmax to shared memory for the quantization step.
- rowStats[row_id] = row_absmax;
- }
-}
-
-template __global__ void
-    kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
-template __global__ void
-    kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
-
template __global__ void kInt8VectorQuant<half, 1024, 0>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);
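
The removed kgetRowStats kernel computed one absmax per row (one block per row, striped loads, block-wide max reduction). Those row statistics are still produced by the remaining int8 vector-wise quantization kernel, whose instantiations stay in place above. A hedged sketch, assuming bitsandbytes.functional.int8_vectorwise_quant keeps its (A, threshold=0.0) signature and a CUDA device is available:

    import torch

    import bitsandbytes.functional as F

    if torch.cuda.is_available():
        A = torch.randn(64, 128, dtype=torch.float16, device="cuda")

        # Quantize row-wise; row_stats is the per-row absmax the deleted kernel
        # used to return, with |a| >= threshold excluded (SPARSE_DECOMP branch).
        quantized, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)

        absA = A.float().abs()
        expected = absA.masked_fill(absA >= 6.0, 0.0).amax(dim=1)
        torch.testing.assert_close(row_stats, expected, rtol=1e-3, atol=1e-3)
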
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index f60e6fdd0..1d9b8b82e 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -109,16 +109,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
-template
-__global__ void kTransformRowToFormat(
- char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
-);
-
template
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template
diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index fdeab46f2..5959bd055 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -1946,49 +1946,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}
}
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
-__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
-    using BlockReduceT = hipcub::BlockReduce<float, THREADS>;
-
- // One block per row.
- // Threads load column values in a striped arrangement.
- // e.g. t0 reads row[0], row[0+nthreads], ..
- // and t1 reads row[1], row[1+nthreads], ..
- // Each thread will determine its local absmax.
- // We then do a blockwise reduction to determine the row's absmax.
-
- __shared__ typename BlockReduceT::TempStorage temp_storage;
-
- const int row_id = blockIdx.x;
- const T* __restrict__ row_data = A + (row_id * cols);
-
- // Threads will read the row values in a striped access pattern and find a local absmax.
- float row_local_absmax = -FLT_MIN;
- for (int i = threadIdx.x; i < cols; i += THREADS) {
- const float absval = fabsf(row_data[i]);
-
- // For sparse decomposition, values outside of the threshold are not to be
- // included when calculating the row's absmax.
- if constexpr (SPARSE_DECOMP) {
- row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
- } else {
- row_local_absmax = fmaxf(row_local_absmax, absval);
- }
- }
-
- // Reduce thread-local absmax across the block.
- // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
- const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
- if (threadIdx.x == 0) {
- // Save our block's absmax to shared memory for the quantization step.
- rowStats[row_id] = row_absmax;
- }
-}
-
-template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
-template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
-
template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh
index 00718071c..efd8a9048 100644
--- a/csrc/kernels_hip.cuh
+++ b/csrc/kernels_hip.cuh
@@ -111,16 +111,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
-template
-__global__ void kTransformRowToFormat(
- char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
-);
-
template
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template
diff --git a/csrc/ops.cu b/csrc/ops.cu
index aafab7522..37a3191bc 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -292,61 +292,6 @@ void strided_gemmex(
int roundoff(int v, int d) { return (v + d - 1) / d * d; }
-template <int ORDER> cublasLtOrder_t get_order() {
- switch (ORDER) {
- case ROW:
- return CUBLASLT_ORDER_ROW;
- break;
- case COL:
- return CUBLASLT_ORDER_COL;
- break;
- case COL32:
- return CUBLASLT_ORDER_COL32;
- break;
- case COL_TURING:
- return CUBLASLT_ORDER_COL4_4R2_8C;
- break;
- case COL_AMPERE:
- return CUBLASLT_ORDER_COL32_2R_4R4;
- break;
- default:
- break;
- }
-
- return CUBLASLT_ORDER_ROW;
-}
-
-template cublasLtOrder_t get_order<ROW>();
-template cublasLtOrder_t get_order<COL>();
-template cublasLtOrder_t get_order<COL32>();
-template cublasLtOrder_t get_order<COL_TURING>();
-template cublasLtOrder_t get_order<COL_AMPERE>();
-
-template <int ORDER> int get_leading_dim(int dim1, int dim2) {
- switch (ORDER) {
- case ROW:
- return dim2;
- break;
- case COL:
- return dim1;
- break;
- case COL32:
- // 32*row tiles
- return dim1 * 32;
- break;
- case COL_TURING:
- return 32 * roundoff(dim1, 8);
- break;
- case COL_AMPERE:
- // 32*32 tiles
- return 32 * roundoff(dim1, 32);
- break;
- default:
- return 0;
- break;
- }
-}
-
template
int igemmlt(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
@@ -449,14 +394,6 @@ void int8VectorQuant(
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
-    if (threshold == 0.0)
-        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    else
-        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
void spmm_coo(
cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
@@ -730,7 +667,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);
-
-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
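
The deleted get_leading_dim helper encoded the leading dimensions of the old cuBLASLt tile layouts. A small Python transcription of the removed C++, kept here purely for reference (these names mirror the deleted helpers; they are not a bitsandbytes API):

    def roundoff(v: int, d: int) -> int:
        # round v up to the next multiple of d
        return (v + d - 1) // d * d

    def get_leading_dim(order: str, dim1: int, dim2: int) -> int:
        if order == "row":
            return dim2
        if order == "col":
            return dim1
        if order == "col32":        # 32 * row tiles
            return dim1 * 32
        if order == "col_turing":   # 8-row x 32-column tiles (CUBLASLT_ORDER_COL4_4R2_8C)
            return 32 * roundoff(dim1, 8)
        if order == "col_ampere":   # 32 x 32 tiles (CUBLASLT_ORDER_COL32_2R_4R4)
            return 32 * roundoff(dim1, 32)
        raise ValueError(f"unknown order: {order}")

    assert get_leading_dim("col32", 100, 50) == 3200
    assert get_leading_dim("col_turing", 100, 50) == 32 * roundoff(100, 8)
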
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 01e11ff31..9674ee055 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -69,14 +69,6 @@ typedef enum Optimizer_t {
ADEMAMIX = 6
} Optimizer_t;
-typedef enum Transform_t {
- ROW = 0,
- COL = 1,
- COL32 = 2,
- COL_TURING = 3,
- COL_AMPERE = 4,
-} Transform_t;
-
typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
@@ -177,7 +169,6 @@ void cutlass_igemm(
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);
diff --git a/csrc/ops.hip b/csrc/ops.hip
index 2fe68f9bd..55cd43ee5 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -326,75 +326,6 @@ int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
-#ifdef NO_HIPBLASLT
-#else
-template <int ORDER> hipblasLtOrder_t get_order()
-{
- switch(ORDER)
- {
- case ROW:
- return HIPBLASLT_ORDER_ROW;
- break;
- case COL:
- return HIPBLASLT_ORDER_COL;
- break;
- case COL32:
- //return HIPBLASLT_ORDER_COL32;
- return HIPBLASLT_ORDER_COL;
- break;
- case COL_TURING:
- //return HIPBLASLT_ORDER_COL4_4R2_8C;
- return HIPBLASLT_ORDER_COL;
- break;
- case COL_AMPERE:
- //return HIPBLASLT_ORDER_COL32_2R_4R4;
- return HIPBLASLT_ORDER_COL;
- break;
- default:
- break;
- }
-
- return HIPBLASLT_ORDER_ROW;
-}
-
-template hipblasLtOrder_t get_order<ROW>();
-template hipblasLtOrder_t get_order<COL>();
-template hipblasLtOrder_t get_order<COL32>();
-//template hipblasLtOrder_t get_order<COL_TURING>();
-//template hipblasLtOrder_t get_order<COL_AMPERE>();
-#endif
-
-template <int ORDER> int get_leading_dim(int dim1, int dim2)
-{
- switch(ORDER)
- {
- case ROW:
- return dim2;
- break;
- case COL:
- return dim1;
- break;
- default:
- return dim1;
- break;
- /*case COL32:
- // 32*row tiles
- return dim1*32;
- break;
- case COL_TURING:
- return 32*roundoff(dim1, 8);
- break;
- case COL_AMPERE:
- // 32*32 tiles
- return 32*roundoff(dim1, 32);
- break;
- default:
- return 0;
- break;
- */
- }
-}
-
static std::string hipError_to_string(const hipError_t ret)
{
switch(ret)
@@ -603,14 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
-void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) {
-    if (threshold == 0.0)
-        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    else
-        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    CUDA_CHECK_RETURN(hipPeekAtLastError());
-}
-
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
@@ -835,7 +758,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
-
-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh
index 0f8db2ee4..7f9aa5d18 100644
--- a/csrc/ops_hip.cuh
+++ b/csrc/ops_hip.cuh
@@ -71,14 +71,6 @@ typedef enum Optimizer_t {
ADEMAMIX = 6,
} Optimizer_t;
-typedef enum Transform_t {
- ROW = 0,
- COL = 1,
- COL32 = 2,
- COL_TURING = 3,
- COL_AMPERE = 4,
-} Transform_t;
-
typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
@@ -179,7 +171,6 @@ void cutlass_igemm(
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream
);
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream
);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 28121240f..b62bca2ee 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -641,10 +641,6 @@ void cdequant_mm_int32_fp16(
dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
}
-void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
- getRowStats(A, rowStats, threshold, rows, cols, stream);
-}
-
void cint8_vector_quant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {
diff --git a/tests/test_functional.py b/tests/test_functional.py
index e045be28c..5a62fa1d8 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -704,45 +704,6 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
- @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
- @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
- @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
- @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
- @pytest.mark.deprecated
- def test_colrow_absmax(self, dim1, dim2, dims, threshold):
- for i in range(k):
- A = torch.randn(dim1, dim2, device="cuda").half()
-
- assert dims == 2
-
- row_stats1, _ = torch.abs(A.float()).max(1)
- col_stats1, _ = torch.abs(A.float()).max(0)
-
- if threshold > 0.0:
- A_truncated = A.clone()
- A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
- row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
- col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
-
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
-
- nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
- nnz_block_ptr1 = torch.zeros(
- nnz_rows1_counts.shape[0] + 1,
- dtype=nnz_rows1_counts.dtype,
- device=nnz_rows1_counts.device,
- )
- nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
-
- torch.testing.assert_close(col_stats1_trunc, col_stats2)
- torch.testing.assert_close(row_stats1_trunc, row_stats2)
- # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
- else:
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
- assert nnz_block_ptr2 is None
- torch.testing.assert_close(col_stats1, col_stats2)
- torch.testing.assert_close(row_stats1, row_stats2)
-
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
@pytest.mark.deprecated
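
The removed test also built (but no longer checked) a CSR-style row pointer over the potential outliers; the same structure can be reproduced with a cumulative sum, as sketched here:

    import torch

    A = torch.randn(1024, 1024, dtype=torch.float16)
    threshold = 3.0

    # Number of potential outliers (|a| >= threshold) in each row ...
    nnz_per_row = (A.abs() >= threshold).sum(dim=1)
    # ... turned into offsets: entries for row i live in [ptr[i], ptr[i+1]).
    nnz_block_ptr = torch.zeros(nnz_per_row.numel() + 1, dtype=torch.int64)
    nnz_block_ptr[1:] = nnz_per_row.cumsum(dim=0)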