
Commit b944b14

committed
pre-commit: pass black check
1 parent 29f1188 commit b944b14

File tree

5 files changed: +248, -159 lines


ucm/integration/vllm/ucm_sparse/factory.py

Lines changed: 3 additions & 1 deletion
@@ -44,4 +44,6 @@ def create_sparse_method(
 
 # Register available sparse methods
 UcmSparseFactory.register_sparse_method("ESA", "ucm.ucm_sparse.esa", "ESA")
-UcmSparseFactory.register_sparse_method("KvComp", "ucm.sandbox.sparse.kvcomp.kvcomp", "KvComp")
+UcmSparseFactory.register_sparse_method(
+    "KvComp", "ucm.sandbox.sparse.kvcomp.kvcomp", "KvComp"
+)
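
The registration stores a dotted module path and a class name rather than the class object itself, which points to lazy loading. Below is a minimal sketch of how such a string-keyed registry can work; `SparseMethodRegistry` and its `create` helper are hypothetical illustrations, not the actual `UcmSparseFactory` internals (which this diff does not show):

import importlib


class SparseMethodRegistry:
    """Hypothetical sketch: maps a method name to (module path, class name)."""

    _registry: dict[str, tuple[str, str]] = {}

    @classmethod
    def register(cls, name: str, module_path: str, class_name: str) -> None:
        # Only strings are stored, so the backing module is not imported yet.
        cls._registry[name] = (module_path, class_name)

    @classmethod
    def create(cls, name: str, **kwargs):
        if name not in cls._registry:
            raise ValueError(f"Unknown sparse method: {name}")
        module_path, class_name = cls._registry[name]
        # The heavy module is imported only on first use of the method.
        module = importlib.import_module(module_path)
        return getattr(module, class_name)(**kwargs)


# Mirrors the call added in this commit:
SparseMethodRegistry.register("KvComp", "ucm.sandbox.sparse.kvcomp.kvcomp", "KvComp")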

ucm/sandbox/sparse/kvcomp/hash_encoder.py

Lines changed: 92 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,23 @@
2323
"""
2424

2525
import torch
26-
if hasattr(torch, 'npu') and torch.npu.is_available():
26+
27+
if hasattr(torch, "npu") and torch.npu.is_available():
2728
import torch_npu
2829

2930
from ucm.logger import init_logger
31+
3032
logger = init_logger(__name__)
3133

34+
3235
class HashEncoder:
3336
"""
3437
HashEncoder converts a float tensor to a binary hash code tensor,
3538
and it packs every 8 bits into a uint8 number.
3639
"""
40+
3741
def __init__(
38-
self,
39-
input_dim: int,
40-
hash_bits: int,
41-
dtype: torch.dtype,
42-
device: torch.device
42+
self, input_dim: int, hash_bits: int, dtype: torch.dtype, device: torch.device
4343
) -> None:
4444
self.input_dim = input_dim
4545

@@ -49,56 +49,70 @@ def __init__(
         self.hash_bits = hash_bits
 
         # number of uint8 numbers to store hash_bits bits
-        self.hash_numbers = self.hash_bits // 8
+        self.hash_numbers = self.hash_bits // 8
 
         self.dtype = dtype
         self.device = device
 
-        if self.device.type == 'npu':
+        if self.device.type == "npu":
             if dtype not in [torch.float16, torch.float32, torch.float64]:
-                logger.warning("NPU only supports float16, float32 and float64 for hash_weights")
+                logger.warning(
+                    "NPU only supports float16, float32 and float64 for hash_weights"
+                )
                 logger.warning("automatically using float16 for hash_weights now")
-                self.dtype = torch.float16
-
-        self.hash_weights = torch.normal(mean=0, std=2, size=(self.input_dim, self.hash_bits), dtype=self.dtype, device=self.device)
+                self.dtype = torch.float16
 
-        if self.device.type == 'cuda' or self.device.type == 'cpu':
+        self.hash_weights = torch.normal(
+            mean=0,
+            std=2,
+            size=(self.input_dim, self.hash_bits),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        if self.device.type == "cuda" or self.device.type == "cpu":
             self._init_bit_masks()
-
-    def set_hash_weight(
-        self,
-        hash_weights: torch.Tensor
-    ) -> None:
+
+    def set_hash_weight(self, hash_weights: torch.Tensor) -> None:
         if hash_weights.shape != (self.input_dim, self.hash_bits):
-            raise ValueError(f"hash_weights shape {hash_weights.shape} does not match required shape {(self.input_dim, self.hash_bits)}")
+            raise ValueError(
+                f"hash_weights shape {hash_weights.shape} does not match required shape {(self.input_dim, self.hash_bits)}"
+            )
         if hash_weights.dtype != self.dtype:
-            raise ValueError(f"hash_weights dtype {hash_weights.dtype} does not match required dtype {self.dtype}")
+            raise ValueError(
+                f"hash_weights dtype {hash_weights.dtype} does not match required dtype {self.dtype}"
+            )
         if hash_weights.device != self.device:
-            raise ValueError(f"hash_weights device {hash_weights.device} does not match required device {self.device}")
-
+            raise ValueError(
+                f"hash_weights device {hash_weights.device} does not match required device {self.device}"
+            )
+
         self.hash_weights.copy_(hash_weights)
-
+
     def _init_bit_masks(self) -> None:
-        self.bit_masks = torch.pow(2, torch.arange(8, dtype=torch.uint8, device=self.device))
+        self.bit_masks = torch.pow(
+            2, torch.arange(8, dtype=torch.uint8, device=self.device)
+        )
         # shape (1, 1, 8)
-        self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)
-
-    def compute_hash(
-        self,
-        x: torch.Tensor
-    ) -> torch.Tensor:
+        self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)
+
+    def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
         """
         Compute the hash code for input tensor x.
-        Args:
+        Args:
             x: input tensor of shape (..., input_dim)
         Returns:
-            A tensor of shape (..., hash_numbers=hash_bits // 8) representing the hash codes.
+            A tensor of shape (..., hash_numbers=hash_bits // 8) representing the hash codes.
             Each element is a uint8 number representing 8 bits of the hash code.
         """
         if x.shape[-1] != self.input_dim:
-            raise ValueError(f"x must be of shape (..., {self.input_dim}), but got {x.shape}")
+            raise ValueError(
+                f"x must be of shape (..., {self.input_dim}), but got {x.shape}"
+            )
         if x.device != self.device:
-            raise ValueError(f"x device {x.device} does not match required device {self.device}")
+            raise ValueError(
+                f"x device {x.device} does not match required device {self.device}"
+            )
 
         # original shape without the last dimension
         # e.g. x.shape=[s1,s2,s3,input_dim], orig_shape=[s1,s2,s3]
@@ -112,78 +126,87 @@ def compute_hash(
 
         # [N, hash_bits]
         xW = torch.matmul(x_flat, self.hash_weights)
-
+
         # [N * hash_bits]
         xW_flat = xW.view(-1)
 
-        if self.device.type == 'npu':
+        if self.device.type == "npu":
             # [N*hash_numbers], where hash_numbers = hash_bits // 8
             packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1)
-        elif self.device.type == 'cuda' or self.device.type == 'cpu':
+        elif self.device.type == "cuda" or self.device.type == "cpu":
             # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
             # [N, hash_bits]
             projected = (xW > 0).to(torch.uint8)
 
             # [N, hash_numbers, 8]
-            binary_codes = projected.view(-1, self.hash_numbers, 8)
+            binary_codes = projected.view(-1, self.hash_numbers, 8)
 
             # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
             # then sum along the last dimension to get [N, hash_numbers]
-            packed_codes_flat = torch.sum(binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8) # [N, hash_numbers]
+            packed_codes_flat = torch.sum(
+                binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8
+            ) # [N, hash_numbers]
             packed_codes_flat = packed_codes_flat.view(-1) # [N * hash_numbers]
         else:
             raise ValueError(f"Unsupported device type: {self.device.type}")
 
         # e.g., [s1, s2, s3, hash_numbers]
         out_shape = orig_shape + (self.hash_numbers,)
-        packed_codes = packed_codes_flat.view(out_shape)
+        packed_codes = packed_codes_flat.view(out_shape)
 
         return packed_codes
 
-    def _unpack_hash(
-        self,
-        packed_codes: torch.Tensor
-    ) -> torch.Tensor:
+    def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
         """
         Unpack the hash codes to +1 or -1 bits.
         Args:
             packed_codes: input tensor of shape (..., hash_numbers), dtype=torch.uint8
         Returns:
-            A tensor of shape (..., hash_bits=hash_numbers*8) representing the unpacked bits.
+            A tensor of shape (..., hash_bits=hash_numbers*8) representing the unpacked bits.
             Each element is either -1 or 1.
         """
         if packed_codes.shape[-1] != self.hash_numbers:
-            raise ValueError(f"packed_codes must be of shape (..., {self.hash_numbers}), but got {packed_codes.shape}")
+            raise ValueError(
+                f"packed_codes must be of shape (..., {self.hash_numbers}), but got {packed_codes.shape}"
+            )
         if packed_codes.device != self.device:
-            raise ValueError(f"packed_codes device {packed_codes.device} does not match required device {self.device}")
+            raise ValueError(
+                f"packed_codes device {packed_codes.device} does not match required device {self.device}"
+            )
         if packed_codes.dtype != torch.uint8:
-            raise ValueError(f"packed_codes dtype {packed_codes.dtype} is not torch.uint8")
+            raise ValueError(
+                f"packed_codes dtype {packed_codes.dtype} is not torch.uint8"
+            )
 
         # e.g., packed_codes.shape=[s1, s2, s3, hash_numbers]
         # orig_shape = [s1, s2, s3]
         orig_shape = packed_codes.shape[:-1]
-
+
         # [N * hash_numbers], e.g., N = s1*s2*s3
         packed_codes_flat = packed_codes.view(-1)
 
-        if self.device.type == 'npu':
+        if self.device.type == "npu":
             # [N * hash_bits]
-            unpacked_bits_flat = torch_npu.npu_sign_bits_unpack(packed_codes_flat, size=1, dtype=torch.float16)
-        elif self.device.type == 'cuda' or self.device.type == 'cpu':
+            unpacked_bits_flat = torch_npu.npu_sign_bits_unpack(
+                packed_codes_flat, size=1, dtype=torch.float16
+            )
+        elif self.device.type == "cuda" or self.device.type == "cpu":
             # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
             # [N, hash_numbers]
-            packed_codes_2d = packed_codes_flat.view(-1, self.hash_numbers)
+            packed_codes_2d = packed_codes_flat.view(-1, self.hash_numbers)
 
             # [N, hash_numbers, 8]
-            expanded = packed_codes_2d.unsqueeze(-1).expand(-1, -1, 8) # expand last dim to 8
+            expanded = packed_codes_2d.unsqueeze(-1).expand(
+                -1, -1, 8
+            ) # expand last dim to 8
 
             # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8]
-            unpacked_bits = (expanded & self.bit_masks) > 0
+            unpacked_bits = (expanded & self.bit_masks) > 0
 
             # 0 -> -1, 1 -> 1
-            unpacked_bits = unpacked_bits*2-1
+            unpacked_bits = unpacked_bits * 2 - 1
 
-            unpacked_bits = unpacked_bits.to(torch.float16)
+            unpacked_bits = unpacked_bits.to(torch.float16)
 
             # [N, hash_bits]
             unpacked_bits_flat = unpacked_bits.view(-1, self.hash_bits)
@@ -195,13 +218,14 @@ def _unpack_hash(
 
         return unpacked_bits
 
+
 if __name__ == "__main__":
-    if hasattr(torch, 'npu') and torch.npu.is_available():
-        device=torch.device("npu:0")
-    elif hasattr(torch, 'cuda') and torch.cuda.is_available():
-        device=torch.device("cuda:0")
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        device = torch.device("npu:0")
+    elif hasattr(torch, "cuda") and torch.cuda.is_available():
+        device = torch.device("cuda:0")
     else:
-        device=torch.device("cpu")
+        device = torch.device("cpu")
 
     print("Using device:", device)
 
@@ -220,5 +244,9 @@ def _unpack_hash(
     print("unpacked_bits:", unpacked_bits)
     print("unpacked_bits shape:", unpacked_bits.shape)
 
-    print(f"hash_codes[0].item()={hash_codes[0].item()}, 8-bit binary form:{hash_codes[0].item():08b}")
-    print(f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}")
+    print(
+        f"hash_codes[0].item()={hash_codes[0].item()}, 8-bit binary form:{hash_codes[0].item():08b}"
+    )
+    print(
+        f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}"
+    )
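
For reference, `HashEncoder` hashes float vectors with a random projection and packs the sign bits eight at a time into uint8 values. A trimmed, CPU-only usage sketch based on the signatures visible in this diff (tensor sizes are arbitrary; `_unpack_hash` is a private helper, used here the same way the module's own `__main__` block uses it):

import torch

from ucm.sandbox.sparse.kvcomp.hash_encoder import HashEncoder

device = torch.device("cpu")
# 128-dim inputs hashed to 64 bits, stored as 64 // 8 = 8 uint8 numbers.
encoder = HashEncoder(input_dim=128, hash_bits=64, dtype=torch.float32, device=device)

x = torch.randn(4, 16, 128, dtype=torch.float32, device=device)
hash_codes = encoder.compute_hash(x)
print(hash_codes.shape, hash_codes.dtype)  # torch.Size([4, 16, 8]) torch.uint8

# Recover the sign pattern (+1 / -1 per bit) from the packed codes.
unpacked = encoder._unpack_hash(hash_codes)
print(unpacked.shape, unpacked.dtype)  # torch.Size([4, 16, 64]) torch.float16

Each uint8 packs eight sign bits weighted by 2^0 through 2^7, so a code of 177 (binary 10110001) means bits 0, 4, 5, and 7 of that group projected positive.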
