@@ -109,6 +109,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
109109 self ._need_load_reqs : dict [str , Union [list [int ], list [Task ]]] = {}
110110 self ._load_failed_reqs : set [str ] = set ()
111111 self ._load_req_to_blocks : dict [str , set [int ]] = {}
112+ if role == KVConnectorRole .WORKER :
113+ self ._initialize_dataoffset (vllm_config )
112114 if (
113115 self ._vllm_config .kv_transfer_config is not None
114116 and "ucm_connector_name"
@@ -156,37 +158,34 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
156158 forward_context .virtual_engine
157159 ]
158160
def _initialize_dataoffset(self, vllm_config: "VllmConfig"):
    """Precompute the K-cache offset of every layer (and TP rank) once.

    Replaces the former per-call ``DataOffset`` computation with a lookup
    table, so the per-block loop in ``get_tensor_and_offset_layerwise`` only
    has to index ``self.k_data_offsets``.

    Reads ``self.block_size``, ``self.element_size``, ``self.total_tp_size``,
    ``self.num_layers`` and ``self.is_mla``; these must be set before this
    method is called.

    Sets:
        self.min_block_size: ``block_size * num_kv_heads * head_size *
            element_size`` -- the size of one K (or one V) block of one layer
            for one rank (in bytes, assuming ``element_size`` is a per-element
            byte count -- TODO confirm).
        self.k_data_offsets: ``layer_id -> rank -> k_offset``. In the MLA case
            only rank ``0`` is stored per layer, because the offset does not
            depend on the rank there.

    Args:
        vllm_config: Config object providing ``model_config`` (queried for the
            number of KV heads and the head size) and ``parallel_config``.
    """
    model_config = vllm_config.model_config
    num_kv_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
    head_size = model_config.get_head_size()
    # NOTE(review): for MLA this assumes num_kv_heads * head_size equals the
    # per-token width of the single latent cache tensor -- confirm.
    self.min_block_size = (
        self.block_size * num_kv_heads * head_size * self.element_size
    )

    # layer_id -> rank -> k_offset
    self.k_data_offsets: dict[int, dict[int, int]] = {}

    if self.is_mla:
        # MLA: a layer occupies exactly one block and is not split per rank,
        # so the table is filled independently of the TP size.
        layer_size = self.min_block_size
        for layer_id in range(self.num_layers):
            self.k_data_offsets[layer_id] = {0: layer_size * layer_id}
        return

    # Non-MLA: each rank owns one K block followed by one V block, so ranks
    # are spaced 2 * min_block_size apart inside a layer (this is exactly
    # layer_size // total_tp_size, without the division).
    rank_stride = 2 * self.min_block_size
    layer_size = rank_stride * self.total_tp_size
    for layer_id in range(self.num_layers):
        layer_base = layer_size * layer_id
        self.k_data_offsets[layer_id] = {
            rank: layer_base + rank_stride * rank
            for rank in range(self.total_tp_size)
        }
190189
191190 def get_tensor_and_offset_layerwise (
192191 self , vllm_block_ids : List [int ], kv_layer : torch .Tensor , layer_name : str
@@ -198,14 +197,17 @@ def get_tensor_and_offset_layerwise(
198197 layer_id = self ._extract_layer_index (layer_name )
199198
200199 for blk_id in vllm_block_ids :
201- k_data_offset = self .DataOffset (kv_layer , self .rank , layer_id , False )
202200 if self .is_mla :
201+ k_data_offset = self .k_data_offsets [layer_id ][0 ]
203202 k_tensors .append (kv_layer [blk_id ])
204203 else :
204+ k_data_offset = self .k_data_offsets [layer_id ][self .rank ]
205205 k_tensors .append (kv_layer [0 ][blk_id ])
206206 k_offsets .append (k_data_offset )
207207 if not self .is_mla :
208- v_data_offset = self .DataOffset (kv_layer , self .rank , layer_id , True )
208+ v_data_offset = (
209+ self .k_data_offsets [layer_id ][self .rank ] + self .min_block_size
210+ )
209211 v_tensors .append (kv_layer [1 ][blk_id ])
210212 v_offsets .append (v_data_offset )
211213 return k_tensors + v_tensors , k_offsets + v_offsets
0 commit comments