2323"""
2424
2525import torch
26- if hasattr (torch , 'npu' ) and torch .npu .is_available ():
26+
27+ if hasattr (torch , "npu" ) and torch .npu .is_available ():
2728 import torch_npu
2829
2930from ucm .logger import init_logger
31+
3032logger = init_logger (__name__ )
3133
34+
3235class HashEncoder :
3336 """
3437 HashEncoder converts a float tensor to a binary hash code tensor,
3538 and it packs every 8 bits into a uint8 number.
3639 """

    def __init__(
        self, input_dim: int, hash_bits: int, dtype: torch.dtype, device: torch.device
    ) -> None:
        self.input_dim = input_dim

        # ... (unchanged lines omitted in the diff) ...

        self.hash_bits = hash_bits

        # number of uint8 numbers to store hash_bits bits
        self.hash_numbers = self.hash_bits // 8

        self.dtype = dtype
        self.device = device

        if self.device.type == "npu":
            if dtype not in [torch.float16, torch.float32, torch.float64]:
                logger.warning(
                    "NPU only supports float16, float32 and float64 for hash_weights"
                )
                logger.warning("automatically using float16 for hash_weights now")
                self.dtype = torch.float16

        self.hash_weights = torch.normal(
            mean=0,
            std=2,
            size=(self.input_dim, self.hash_bits),
            dtype=self.dtype,
            device=self.device,
        )

        if self.device.type == "cuda" or self.device.type == "cpu":
            self._init_bit_masks()

    def set_hash_weight(self, hash_weights: torch.Tensor) -> None:
        if hash_weights.shape != (self.input_dim, self.hash_bits):
            raise ValueError(
                f"hash_weights shape {hash_weights.shape} does not match required shape {(self.input_dim, self.hash_bits)}"
            )
        if hash_weights.dtype != self.dtype:
            raise ValueError(
                f"hash_weights dtype {hash_weights.dtype} does not match required dtype {self.dtype}"
            )
        if hash_weights.device != self.device:
            raise ValueError(
                f"hash_weights device {hash_weights.device} does not match required device {self.device}"
            )

        self.hash_weights.copy_(hash_weights)

    def _init_bit_masks(self) -> None:
        self.bit_masks = torch.pow(
            2, torch.arange(8, dtype=torch.uint8, device=self.device)
        )
        # shape (1, 1, 8)
        self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0)
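        # bit_masks is [1, 2, 4, ..., 128], so packing is LSB-first: the i-th
        # bit of every group of 8 projected bits lands in bit i of its uint8.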

    def compute_hash(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the hash code for input tensor x.
        Args:
            x: input tensor of shape (..., input_dim)
        Returns:
            A tensor of shape (..., hash_numbers=hash_bits // 8) representing the hash codes.
            Each element is a uint8 number representing 8 bits of the hash code.
        """
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"x must be of shape (..., {self.input_dim}), but got {x.shape}"
            )
        if x.device != self.device:
            raise ValueError(
                f"x device {x.device} does not match required device {self.device}"
            )

        # original shape without the last dimension
        # e.g. x.shape=[s1,s2,s3,input_dim], orig_shape=[s1,s2,s3]
        # ... (unchanged lines omitted in the diff) ...

        # [N, hash_bits]
        xW = torch.matmul(x_flat, self.hash_weights)

        # [N * hash_bits]
        xW_flat = xW.view(-1)

        if self.device.type == "npu":
            # [N*hash_numbers], where hash_numbers = hash_bits // 8
            packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1)
        elif self.device.type == "cuda" or self.device.type == "cpu":
            # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
            # [N, hash_bits]
            projected = (xW > 0).to(torch.uint8)

            # [N, hash_numbers, 8]
            binary_codes = projected.view(-1, self.hash_numbers, 8)

            # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8]
            # then sum along the last dimension to get [N, hash_numbers]
            packed_codes_flat = torch.sum(
                binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8
            )  # [N, hash_numbers]
            packed_codes_flat = packed_codes_flat.view(-1)  # [N * hash_numbers]
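            # Worked example (LSB-first): projected bits [1, 0, 1, 1, 0, 0, 1, 0]
            # pack to 1*1 + 1*4 + 1*8 + 1*64 = 77, i.e. 0b01001101.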
        else:
            raise ValueError(f"Unsupported device type: {self.device.type}")

        # e.g., [s1, s2, s3, hash_numbers]
        out_shape = orig_shape + (self.hash_numbers,)
        packed_codes = packed_codes_flat.view(out_shape)

        return packed_codes

    def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor:
        """
        Unpack the hash codes to +1 or -1 bits.
        Args:
            packed_codes: input tensor of shape (..., hash_numbers), dtype=torch.uint8
        Returns:
            A tensor of shape (..., hash_bits=hash_numbers*8) representing the unpacked bits.
            Each element is either -1 or 1.
        """
        if packed_codes.shape[-1] != self.hash_numbers:
            raise ValueError(
                f"packed_codes must be of shape (..., {self.hash_numbers}), but got {packed_codes.shape}"
            )
        if packed_codes.device != self.device:
            raise ValueError(
                f"packed_codes device {packed_codes.device} does not match required device {self.device}"
            )
        if packed_codes.dtype != torch.uint8:
            raise ValueError(
                f"packed_codes dtype {packed_codes.dtype} is not torch.uint8"
            )

        # e.g., packed_codes.shape=[s1, s2, s3, hash_numbers]
        # orig_shape = [s1, s2, s3]
        orig_shape = packed_codes.shape[:-1]

        # [N * hash_numbers], e.g., N = s1*s2*s3
        packed_codes_flat = packed_codes.view(-1)

        if self.device.type == "npu":
            # [N * hash_bits]
            unpacked_bits_flat = torch_npu.npu_sign_bits_unpack(
                packed_codes_flat, size=1, dtype=torch.float16
            )
        elif self.device.type == "cuda" or self.device.type == "cpu":
            # (TODO) improve performance later on CUDA ops and CPU SIMD instructions
            # [N, hash_numbers]
            packed_codes_2d = packed_codes_flat.view(-1, self.hash_numbers)

            # [N, hash_numbers, 8]
            expanded = packed_codes_2d.unsqueeze(-1).expand(
                -1, -1, 8
            )  # expand last dim to 8

            # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8]
            unpacked_bits = (expanded & self.bit_masks) > 0

            # 0 -> -1, 1 -> 1
            unpacked_bits = unpacked_bits * 2 - 1
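            # For codes produced by compute_hash on CPU/CUDA this recovers the
            # sign pattern of the projection: +1 where xW > 0, -1 elsewhere.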
            unpacked_bits = unpacked_bits.to(torch.float16)

            # [N, hash_bits]
            unpacked_bits_flat = unpacked_bits.view(-1, self.hash_bits)
        # ... (unchanged lines omitted in the diff) ...

        return unpacked_bits


if __name__ == "__main__":
    if hasattr(torch, "npu") and torch.npu.is_available():
        device = torch.device("npu:0")
    elif hasattr(torch, "cuda") and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print("Using device:", device)
    # ... (unchanged lines omitted in the diff) ...

    print("unpacked_bits:", unpacked_bits)
    print("unpacked_bits shape:", unpacked_bits.shape)

    print(
        f"hash_codes[0].item()={hash_codes[0].item()}, 8-bit binary form:{hash_codes[0].item():08b}"
    )
    print(
        f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}"
    )