@@ -141,6 +141,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.element_size = vllm_config.model_config.dtype.itemsize
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
+        self._load_failed_reqs: set[str] = set()
+        self._load_req_to_blocks: dict[str, set[int]] = {}
         if (
             self._vllm_config.kv_transfer_config is not None
             and "ucm_connector_name"
@@ -280,6 +282,16 @@ def load(tensor_list, offset_list) -> tuple[Task, Task]:
     # ==============================
     # Worker-side methods
     # ==============================
+    def clear_connector_metadata(self) -> None:
+        """Clear the connector metadata.
+
+        This function should be called by the model runner after
+        every model execution.
+        """
+        self._load_failed_reqs.clear()
+        self._load_req_to_blocks.clear()
+        super().clear_connector_metadata()
+
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         """
         Start loading the KV cache from the connector to vLLM's paged
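
The `clear_connector_metadata` override plugs into the per-step lifecycle of vLLM's v1 connector API, where the worker binds scheduler-produced metadata before the forward pass and clears it afterwards; resetting the two new attributes here keeps one step's load failures from leaking into the next. A runnable toy sketch of that timing, assuming nothing beyond what the docstring states:

```python
class LifecycleSketch:
    """Toy stand-in for the connector; only the clear timing matters."""

    def __init__(self) -> None:
        self._load_failed_reqs: set[str] = set()
        self._load_req_to_blocks: dict[str, set[int]] = {}

    def clear_connector_metadata(self) -> None:
        # Mirrors the override above: drop per-step failure state.
        self._load_failed_reqs.clear()
        self._load_req_to_blocks.clear()

    def execute_model_step(self) -> None:
        try:
            self._load_failed_reqs.add("req-1")  # pretend a load failed
        finally:
            # "called by the model runner after every model execution"
            self.clear_connector_metadata()


sketch = LifecycleSketch()
sketch.execute_model_step()
assert not sketch._load_failed_reqs  # failure state never outlives the step
```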
@@ -306,7 +318,6 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         for request in metadata.requests:
             if request.load_paras is None or not request.load_paras.can_load:
                 continue
-            layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
             block_ids = request.vllm_block_ids
             # Block ids to load should start after the last vLLM-cached block
             load_start_block_id = (
@@ -327,6 +338,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             ]
             assert len(fetch_block_ids) == len(fetch_block_hashes)
             blocks_len = len(fetch_block_ids)
+            self._load_req_to_blocks.setdefault(request.request_id, set()).update(
+                fetch_block_ids
+            )
             for layer_name, kv_layer in self.kv_caches.items():
                 tensors, offsets = self.get_tensor_and_offset_layerwise(
                     fetch_block_ids, kv_layer, layer_name
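
Recording the fetched block ids per request uses the `setdefault`-then-`update` idiom, so multiple fetch batches for the same request accumulate into a single set. A tiny self-contained demonstration with made-up ids:

```python
# Made-up request ids and block ids; same accumulation pattern as above.
load_req_to_blocks: dict[str, set[int]] = {}
for req_id, fetched in [("req-1", [4, 5]), ("req-1", [5, 6]), ("req-2", [9])]:
    load_req_to_blocks.setdefault(req_id, set()).update(fetched)

assert load_req_to_blocks == {"req-1": {4, 5, 6}, "req-2": {9}}
```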
@@ -363,9 +377,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             for _, (k_task, v_task) in self.layerwise_load_tasks[
                 request.request_id
             ].items():
-                assert self.connector.wait(k_task) == 0
-                if not self.is_mla:
-                    assert self.connector.wait(v_task) == 0
+                if self.connector.wait(k_task) != 0:
+                    self._load_failed_reqs.add(request.request_id)
+                    break
+                if v_task and self.connector.wait(v_task) != 0:
+                    self._load_failed_reqs.add(request.request_id)
+                    break
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         """
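
Swapping the asserts for failure recording changes the contract: a failed `wait` no longer kills the worker, it marks the request and abandons its remaining tasks. The `v_task and ...` guard replaces the old `is_mla` check, since `v_task` is None for MLA models where K and V share one tensor. The same pattern recurs in `wait_for_layer_load` below; a hypothetical helper capturing it (only `wait` returning 0 on success is taken from the diff):

```python
def wait_pair(connector, k_task, v_task, request_id: str,
              failed_reqs: set) -> bool:
    """Hypothetical helper: wait on a K/V task pair, recording failures.

    Returns True only when every wait reports 0 (success), mirroring the
    non-fatal handling introduced above.
    """
    if connector.wait(k_task) != 0:
        failed_reqs.add(request_id)
        return False
    # v_task is None (falsy) for MLA models, where K and V share one tensor.
    if v_task and connector.wait(v_task) != 0:
        failed_reqs.add(request_id)
        return False
    return True
```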
@@ -387,10 +404,16 @@ def wait_for_layer_load(self, layer_name: str) -> None:
             self.current_layer < self.num_layers
         ), "The current layer should be less than total layers!"
         for request_id, layer_to_task in self.layerwise_load_tasks.items():
+            if request_id in self._load_failed_reqs:
+                continue
             k_task, v_task = layer_to_task[layer_name]
-            assert self.connector.wait(k_task) == 0
+            if self.connector.wait(k_task) != 0:
+                self._load_failed_reqs.add(request_id)
+                continue
             if not self.is_mla:
-                assert self.connector.wait(v_task) == 0
+                if self.connector.wait(v_task) != 0:
+                    self._load_failed_reqs.add(request_id)
+                    continue
             logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
 
     def save_kv_layer(
@@ -550,16 +573,18 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """Get the finished recving and sending requests."""
         done_recving: set[str] = set()
         for req_id, tasks in self._need_load_reqs.items():
+            if req_id in self._load_failed_reqs:
+                continue
             unfinished_tasks = []
             for task in tasks:
                 ret = self.connector.check(task)
-                if ret == 0:
-                    # Remove this assertion after support recompute load failed reqs
-                    assert self.connector.wait(task) == 0
+                if ret == -1:
+                    unfinished_tasks.append(task)
+                    continue
+                elif ret == 0 and self.connector.wait(task) == 0:
                     continue
-                if ret != -1:
-                    raise ValueError(f"Task {task.get_id()} Not Found")
-                unfinished_tasks.append(task)
+                self._load_failed_reqs.add(req_id)
+                break
             if not unfinished_tasks:
                 done_recving.add(req_id)
             self._need_load_reqs[req_id] = unfinished_tasks
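
The rewritten polling loop encodes a three-way protocol, as far as the return codes used here reveal it: `check()` returning -1 means the task is still in flight, 0 means it can be reaped with `wait()`, and anything else marks the whole request as a failed load. Restated as a standalone classifier (a hypothetical function, not connector API):

```python
def classify_task(connector, task) -> str:
    """Restate the per-task decision in get_finished (hypothetical helper)."""
    ret = connector.check(task)
    if ret == -1:
        return "pending"  # still running: stays in unfinished_tasks
    if ret == 0 and connector.wait(task) == 0:
        return "done"     # completed and reaped cleanly
    return "failed"       # anything else poisons the entire request
```

A single failed task breaks out immediately: the request lands in `_load_failed_reqs`, is skipped on later `get_finished` calls, and its blocks are reported through `get_block_ids_with_load_errors` instead of being treated as loaded.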
@@ -624,10 +649,19 @@ def md5(input) -> int:
         )
 
         need_load_tokens = max(num_external_computed_tokens - num_computed_tokens, 0)
+        # Load asynchronously when the decode instance needs to load.
         if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
+            # Only trigger one asynchronous KV transfer per request.
+            if (
+                request.kv_transfer_params
+                and request.kv_transfer_params.get("load_async") is False
+            ):
+                return 0, False
+            request.kv_transfer_params = request.kv_transfer_params or {}
+            request.kv_transfer_params["load_async"] = False
             if need_load_tokens > 0:
                 self._need_load_reqs[request.request_id] = []
-                return need_load_tokens, True
+            return need_load_tokens, True
 
         num_max_cached_tokens = max(num_external_computed_tokens, num_computed_tokens)
         num_blocks_need_save = (
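
On a `kv_consumer` (decode) instance this allocation hook can run more than once for the same request, so the new `load_async` marker in `kv_transfer_params` makes the asynchronous transfer fire exactly once. A runnable sketch of the gating across two passes; `SimpleNamespace` stands in for the real `Request` and `need_load_tokens` is a made-up value:

```python
from types import SimpleNamespace

request = SimpleNamespace(kv_transfer_params=None)
need_load_tokens = 128  # assumed for illustration


def try_trigger_async_load(req) -> tuple:
    # Later passes find the marker and skip the transfer entirely.
    if req.kv_transfer_params and \
            req.kv_transfer_params.get("load_async") is False:
        return 0, False
    # First pass: plant the marker, then report the tokens to load async.
    req.kv_transfer_params = req.kv_transfer_params or {}
    req.kv_transfer_params["load_async"] = False
    return need_load_tokens, True


assert try_trigger_async_load(request) == (128, True)  # triggers the load
assert try_trigger_async_load(request) == (0, False)   # repeat is a no-op
```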
@@ -778,6 +812,13 @@ def request_finished(
             self.connector.commit(cancel_blocks, False)
             return False, None
 
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        invalid_block_ids: set[int] = set()
+        for req_id in self._load_failed_reqs:
+            if req_id in self._load_req_to_blocks:
+                invalid_block_ids.update(self._load_req_to_blocks[req_id])
+        return invalid_block_ids
+
     @staticmethod
     def _extract_layer_index(layer_name: str) -> Optional[int]:
         """
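
`get_block_ids_with_load_errors` is the hook vLLM's v1 connector API exposes for reporting blocks whose external load failed, so the affected tokens are recomputed rather than read from a partially written KV cache. A hedged sketch of a caller-side consumer; the stub class and the recovery step are stand-ins, not vLLM internals:

```python
class ConnectorStub:
    """Stand-in exposing only the new hook, seeded with made-up state."""

    def __init__(self) -> None:
        self._load_failed_reqs = {"req-1"}
        self._load_req_to_blocks = {"req-1": {4, 5}, "req-2": {9}}

    def get_block_ids_with_load_errors(self) -> set:
        invalid: set = set()
        for req_id in self._load_failed_reqs:
            invalid.update(self._load_req_to_blocks.get(req_id, set()))
        return invalid


invalid = ConnectorStub().get_block_ids_with_load_errors()
# A model runner would surface these so the scheduler triggers a recompute.
assert invalid == {4, 5}  # req-2 succeeded, so its block stays valid
```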