diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py
index f262d89ed..e69de29bb 100644
--- a/bitsandbytes/autograd/__init__.py
+++ b/bitsandbytes/autograd/__init__.py
@@ -1 +0,0 @@
-from ._functions import get_inverse_transform_indices, undo_layout
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 5391c8522..67420af3c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from dataclasses import dataclass
 from math import prod
 from typing import Optional
@@ -6,7 +5,6 @@ from warnings import warn
 
 import torch
 
-from typing_extensions import deprecated
 
 import bitsandbytes.functional as F
 
@@ -50,66 +48,9 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def get_inverse_transform_indices(
-    transform_tile: Callable[[torch.Tensor], torch.Tensor],
-    tile_size: tuple[int, int],
-):
-    """
-    Compute a permutation of indices that invert the specified (tiled) matrix transformation
-
-    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
-    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
-    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
-    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
-    :returns: indices
-    """
-    d1, d2 = tile_size
-    assert 0 < d1 * d2 < 2**64
-    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
-    # encode each position in tile as a tuple of <= 8 unique bytes
-    permuted_tile_indices = torch.zeros_like(tile_indices)
-    for i in range(8):
-        # select i-th byte, apply transformation and trace where each index ended up
-        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
-        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
-        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
-        permuted_tile_i = transform_tile(sample_tile_i)
-        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
-        permuted_tile_indices += ith_permuted_indices * (256**i)
-        if d1 * d2 < 256**i:
-            break  # if all indices fit in i bytes, stop early
-    return permuted_tile_indices
-
-
 _is_compiling = torch.compiler.is_compiling
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
-    """
-    Undo a tiled permutation such as turing or ampere layout
-
-    :param permuted_tensor: torch tensor in a permuted layout
-    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
-    :return: contiguous row-major tensor
-    """
-    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
-    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
-    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
-    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
-    outputs[tile_indices.flatten()] = tensor
-    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
-    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
-    return outputs.reshape(rows, cols).contiguous()
-
-
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 3d11276ad..abf160c7e 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
     return result
 
 
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_colrow_absmax(
-    A: torch.Tensor,
-    row_stats: Optional[torch.Tensor] = None,
-    col_stats: Optional[torch.Tensor] = None,
-    nnz_block_ptr: Optional[torch.Tensor] = None,
-    threshold=0.0,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """ "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
-    The row-wise and column-wise absmax values are determined.
-
-    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-    <Tip>
-    This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
-    The column-wise quantization scales are not typically needed in inference scenarios.
-    </Tip>
-
-    Args:
-        A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
-        row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
-        col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
-        nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
-        threshold (`float`, *optional*):
-            An optional threshold for sparse decomposition of outlier features.
-            No outliers are held back when 0.0. Defaults to 0.0.
-
-    Returns:
-        `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
-        - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
-        - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
-        - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
-    """
-    assert A.is_floating_point()
-
-    outlier_mask = None
-
-    if row_stats is None or col_stats is None:
-        absA = A.abs().view(-1, A.shape[-1])
-
-        if threshold > 0.0:
-            # Filter outliers from stats when enabled
-            outlier_mask = absA >= threshold
-            absA.masked_fill_(outlier_mask, 0.0)
-
-        if row_stats is None:
-            # shape [rows]; unsqueeze(-1) gives [rows,1]
-            # We have a CUDA kernel for row max, but not yet for cols.
-            row_stats = get_row_absmax(A, threshold)
-
-        if col_stats is None:
-            # shape [cols]; unsqueeze(0) gives [1,cols]
-            col_stats = absA.amax(dim=0, keepdim=False).float()
-
-    return row_stats, col_stats, outlier_mask
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_row_absmax(A: torch.Tensor, threshold=0.0):
-    """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
-    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-    Args:
-        A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
-        threshold (`float`, *optional*):
-            An optional threshold for sparse decomposition of outlier features.
-            No outliers are held back when 0.0. Defaults to 0.0.
-
-    Returns:
-        `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
-    """
-
-    assert A.dtype == torch.float16
-
-    rows = prod(A.shape[:-1])
-    cols = A.shape[-1]
-
-    row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)
-
-    is_on_gpu([A])
-
-    with _cuda_device_of(A):
-        lib.cget_row_stats(
-            get_ptr(A),
-            get_ptr(row_stats),
-            ct.c_float(threshold),
-            ct.c_int32(rows),
-            ct.c_int32(cols),
-            _get_tensor_stream(A),
-        )
-
-    return row_stats
-
-
 class COOSparseTensor:
     def __init__(
         self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index de48f5e82..ff122d376 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1825,51 +1825,6 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
     }
 }
 
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
-    void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
-    using BlockReduceT = cub::BlockReduce<float, THREADS>;
-
-    // One block per row.
-    // Threads load column values in a striped arrangement.
-    // e.g. t0 reads row[0], row[0+nthreads], ..
-    // and t1 reads row[1], row[1+nthreads], ..
-    // Each thread will determine its local absmax.
-    // We then do a blockwise reduction to determine the row's absmax.
-
-    __shared__ typename BlockReduceT::TempStorage temp_storage;
-
-    const int row_id = blockIdx.x;
-    const T* __restrict__ row_data = A + (row_id * cols);
-
-    // Threads will read the row values in a striped access pattern and find a local absmax.
-    float row_local_absmax = -FLT_MIN;
-    for (int i = threadIdx.x; i < cols; i += THREADS) {
-        const float absval = fabsf(row_data[i]);
-
-        // For sparse decomposition, values outside of the threshold are not to be
-        // included when calculating the row's absmax.
-        if constexpr (SPARSE_DECOMP) {
-            row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
-        } else {
-            row_local_absmax = fmaxf(row_local_absmax, absval);
-        }
-    }
-
-    // Reduce thread-local absmax across the block.
-    // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
-    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
-    if (threadIdx.x == 0) {
-        // Save our block's absmax to shared memory for the quantization step.
-        rowStats[row_id] = row_absmax;
-    }
-}
-
-template __global__ void
-    kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
-template __global__ void
-    kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
-
 template __global__ void kInt8VectorQuant<half, 1024, 0>(
     half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
 );
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index f60e6fdd0..1d9b8b82e 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -109,16 +109,9 @@ __global__ void kdequant_mm_int32_fp16(
     half* __restrict__ const bias, const int numRows, const int numCols, const int n
 );
 
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
 template <typename T, int THREADS, int SPARSE_DECOMP>
 __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
 
-template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
-__global__ void kTransformRowToFormat(
-    char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
-);
-
 template <typename T, int BITS, int THREADS>
 __global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
 template
diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index fdeab46f2..5959bd055 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -1946,49 +1946,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
     }
 }
 
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
-__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
-    using BlockReduceT = hipcub::BlockReduce<float, THREADS>;
-
-    // One block per row.
-    // Threads load column values in a striped arrangement.
-    // e.g. t0 reads row[0], row[0+nthreads], ..
-    // and t1 reads row[1], row[1+nthreads], ..
-    // Each thread will determine its local absmax.
-    // We then do a blockwise reduction to determine the row's absmax.
-
-    __shared__ typename BlockReduceT::TempStorage temp_storage;
-
-    const int row_id = blockIdx.x;
-    const T* __restrict__ row_data = A + (row_id * cols);
-
-    // Threads will read the row values in a striped access pattern and find a local absmax.
-    float row_local_absmax = -FLT_MIN;
-    for (int i = threadIdx.x; i < cols; i += THREADS) {
-        const float absval = fabsf(row_data[i]);
-
-        // For sparse decomposition, values outside of the threshold are not to be
-        // included when calculating the row's absmax.
-        if constexpr (SPARSE_DECOMP) {
-            row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
-        } else {
-            row_local_absmax = fmaxf(row_local_absmax, absval);
-        }
-    }
-
-    // Reduce thread-local absmax across the block.
-    // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
-    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
-    if (threadIdx.x == 0) {
-        // Save our block's absmax to shared memory for the quantization step.
-        rowStats[row_id] = row_absmax;
-    }
-}
-
-template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
-template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
-
 template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
 template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
 
diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh
index 00718071c..efd8a9048 100644
--- a/csrc/kernels_hip.cuh
+++ b/csrc/kernels_hip.cuh
@@ -111,16 +111,9 @@ __global__ void kdequant_mm_int32_fp16(
     half* __restrict__ const bias, const int numRows, const int numCols, const int n
 );
 
-template <typename T, int THREADS, int SPARSE_DECOMP>
-__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
 template <typename T, int THREADS, int SPARSE_DECOMP>
 __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
 
-template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
-__global__ void kTransformRowToFormat(
-    char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
-);
-
 template <typename T, int BITS, int THREADS>
 __global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
 template
diff --git a/csrc/ops.cu b/csrc/ops.cu
index aafab7522..37a3191bc 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -292,61 +292,6 @@ void strided_gemmex(
 
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
-template <int ORDER> cublasLtOrder_t get_order() {
-    switch (ORDER) {
-    case ROW:
-        return CUBLASLT_ORDER_ROW;
-        break;
-    case COL:
-        return CUBLASLT_ORDER_COL;
-        break;
-    case COL32:
-        return CUBLASLT_ORDER_COL32;
-        break;
-    case COL_TURING:
-        return CUBLASLT_ORDER_COL4_4R2_8C;
-        break;
-    case COL_AMPERE:
-        return CUBLASLT_ORDER_COL32_2R_4R4;
-        break;
-    default:
-        break;
-    }
-
-    return CUBLASLT_ORDER_ROW;
-}
-
-template cublasLtOrder_t get_order<ROW>();
-template cublasLtOrder_t get_order<COL>();
-template cublasLtOrder_t get_order<COL32>();
-template cublasLtOrder_t get_order<COL_TURING>();
-template cublasLtOrder_t get_order<COL_AMPERE>();
-
-template <int ORDER> int get_leading_dim(int dim1, int dim2) {
-    switch (ORDER) {
-    case ROW:
-        return dim2;
-        break;
-    case COL:
-        return dim1;
-        break;
-    case COL32:
-        // 32*row tiles
-        return dim1 * 32;
-        break;
-    case COL_TURING:
-        return 32 * roundoff(dim1, 8);
-        break;
-    case COL_AMPERE:
-        // 32*32 tiles
-        return 32 * roundoff(dim1, 32);
-        break;
-    default:
-        return 0;
-        break;
-    }
-}
-
 template
 int igemmlt(
     cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
@@ -449,14 +394,6 @@ void int8VectorQuant(
     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
-    if (threshold == 0.0)
-        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    else
-        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
 void spmm_coo(
     cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols,
     int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B
@@ -730,7 +667,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
 
 template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
 template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);
-
-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 01e11ff31..9674ee055 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -69,14 +69,6 @@ typedef enum Optimizer_t {
     ADEMAMIX = 6
 } Optimizer_t;
 
-typedef enum Transform_t {
-    ROW = 0,
-    COL = 1,
-    COL32 = 2,
-    COL_TURING = 3,
-    COL_AMPERE = 4,
-} Transform_t;
-
 typedef enum DataType_t {
     General8bit = 0,
     FP4 = 1,
@@ -177,7 +169,6 @@ void cutlass_igemm(
 void dequant_mm_int32_fp16(
     int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
 );
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
 void int8VectorQuant(
     half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
 );
diff --git a/csrc/ops.hip b/csrc/ops.hip
index 2fe68f9bd..55cd43ee5 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -326,75 +326,6 @@ int roundoff(int v, int d) {
     return (v + d - 1) / d * d;
 }
 
-#ifdef NO_HIPBLASLT
-#else
-template <int ORDER> hipblasLtOrder_t get_order()
-{
-    switch(ORDER)
-    {
-        case ROW:
-            return HIPBLASLT_ORDER_ROW;
-            break;
-        case COL:
-            return HIPBLASLT_ORDER_COL;
-            break;
-        case COL32:
-            //return HIPBLASLT_ORDER_COL32;
-            return HIPBLASLT_ORDER_COL;
-            break;
-        case COL_TURING:
-            //return HIPBLASLT_ORDER_COL4_4R2_8C;
-            return HIPBLASLT_ORDER_COL;
-            break;
-        case COL_AMPERE:
-            //return HIPBLASLT_ORDER_COL32_2R_4R4;
-            return HIPBLASLT_ORDER_COL;
-            break;
-        default:
-            break;
-    }
-
-    return HIPBLASLT_ORDER_ROW;
-}
-
-template hipblasLtOrder_t get_order<ROW>();
-template hipblasLtOrder_t get_order<COL>();
-template hipblasLtOrder_t get_order<COL32>();
-//template hipblasLtOrder_t get_order<COL_TURING>();
-//template hipblasLtOrder_t get_order<COL_AMPERE>();
-#endif
-
-template <int ORDER> int get_leading_dim(int dim1, int dim2)
-{
-    switch(ORDER)
-    {
-        case ROW:
-            return dim2;
-            break;
-        case COL:
-            return dim1;
-            break;
-        default:
-            return dim1;
-            break;
-        /*case COL32:
-            // 32*row tiles
-            return dim1*32;
-            break;
-        case COL_TURING:
-            return 32*roundoff(dim1, 8);
-            break;
-        case COL_AMPERE:
-            // 32*32 tiles
-            return 32*roundoff(dim1, 32);
-            break;
-        default:
-            return 0;
-            break;
-        */
-    }
-}
-
 static std::string hipError_to_string(const hipError_t ret) {
     switch(ret)
     {
@@ -603,14 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
     CUDA_CHECK_RETURN(hipPeekAtLastError());
 }
 
-void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) {
-    if (threshold == 0.0)
-        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    else
-        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
-    CUDA_CHECK_RETURN(hipPeekAtLastError());
-}
-
 void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
 {
 
@@ -835,7 +758,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
 
 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
-
-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh
index 0f8db2ee4..7f9aa5d18 100644
--- a/csrc/ops_hip.cuh
+++ b/csrc/ops_hip.cuh
@@ -71,14 +71,6 @@ typedef enum Optimizer_t {
     ADEMAMIX = 6,
 } Optimizer_t;
 
-typedef enum Transform_t {
-    ROW = 0,
-    COL = 1,
-    COL32 = 2,
-    COL_TURING = 3,
-    COL_AMPERE = 4,
-} Transform_t;
-
 typedef enum DataType_t {
     General8bit = 0,
     FP4 = 1,
@@ -179,7 +171,6 @@ void cutlass_igemm(
 void dequant_mm_int32_fp16(
     int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream
 );
-void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream);
 void int8VectorQuant(
     half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream
 );
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 28121240f..b62bca2ee 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -641,10 +641,6 @@ void cdequant_mm_int32_fp16(
     dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
 }
 
-void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
-    getRowStats(A, rowStats, threshold, rows, cols, stream);
-}
-
 void cint8_vector_quant(
     half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
 ) {
diff --git a/tests/test_functional.py b/tests/test_functional.py
index e045be28c..5a62fa1d8 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -704,45 +704,6 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):
         n = C5.numel()
         assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
 
-    @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
-    @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
-    @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
-    @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
-    @pytest.mark.deprecated
-    def test_colrow_absmax(self, dim1, dim2, dims, threshold):
-        for i in range(k):
-            A = torch.randn(dim1, dim2, device="cuda").half()
-
-            assert dims == 2
-
-            row_stats1, _ = torch.abs(A.float()).max(1)
-            col_stats1, _ = torch.abs(A.float()).max(0)
-
-            if threshold > 0.0:
-                A_truncated = A.clone()
-                A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
-                row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
-                col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
-
-                row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
-
-                nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
-                nnz_block_ptr1 = torch.zeros(
-                    nnz_rows1_counts.shape[0] + 1,
-                    dtype=nnz_rows1_counts.dtype,
-                    device=nnz_rows1_counts.device,
-                )
-                nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
-
-                torch.testing.assert_close(col_stats1_trunc, col_stats2)
-                torch.testing.assert_close(row_stats1_trunc, row_stats2)
-                # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
-            else:
-                row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
-                assert nnz_block_ptr2 is None
-                torch.testing.assert_close(col_stats1, col_stats2)
-                torch.testing.assert_close(row_stats1, row_stats2)
-
     @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
     @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
     @pytest.mark.deprecated