2626import hashlib
2727import pickle
2828from dataclasses import dataclass , field
29- from typing import TYPE_CHECKING , Any , Generator , List , Optional
29+ from typing import TYPE_CHECKING , Any , Generator , List , Optional , Union
3030
3131import torch
3232from vllm .config import VllmConfig
@@ -88,6 +88,8 @@ class ReqMeta:
8888 load_paras : Optional [LoadPara ] = None
8989 # Save information
9090 save_paras : Optional [SavePara ] = None
91+ # Marks a request whose KV cache must be loaded asynchronously
92+ load_async : bool = False
9193
9294
9395@dataclass
@@ -103,13 +105,15 @@ def add_request(
103105 vllm_block_ids : list [int ],
104106 load_paras : Optional [LoadPara ] = None ,
105107 save_paras : Optional [SavePara ] = None ,
108+ load_async : bool = False ,
106109 ) -> None :
107110 self .requests .append (
108111 ReqMeta (
109112 request_id = request_id ,
110113 vllm_block_ids = vllm_block_ids ,
111114 load_paras = load_paras ,
112115 save_paras = save_paras ,
116+ load_async = load_async ,
113117 )
114118 )
115119
@@ -136,6 +140,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
136140 )
137141 self .element_size = vllm_config .model_config .dtype .itemsize
138142 self .kv_role = vllm_config .kv_transfer_config .kv_role
143+ self ._need_load_reqs : dict [str , Union [list [int ], list [Task ]]] = {}
139144 if (
140145 self ._vllm_config .kv_transfer_config is not None
141146 and "ucm_connector_name"
@@ -326,35 +331,41 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
326331 tensors , offsets = self .get_tensor_and_offset_layerwise (
327332 fetch_block_ids , kv_layer , layer_name
328333 )
329- if not self .use_layerwise :
330- task = self .connector .load (
331- fetch_block_hashes , offsets [:blocks_len ], tensors [:blocks_len ]
334+ k_task_id = self .connector .load (
335+ fetch_block_hashes , offsets [:blocks_len ], tensors [:blocks_len ]
336+ )
337+ v_task_id = None
338+ if not self .is_mla :
339+ v_task_id = self .connector .load (
340+ fetch_block_hashes ,
341+ offsets [blocks_len :],
342+ tensors [blocks_len :],
332343 )
333- assert self .connector .wait (task ) == 0
344+ if request .request_id not in self .layerwise_load_tasks :
345+ self .layerwise_load_tasks [request .request_id ] = {}
346+ self .layerwise_load_tasks [request .request_id ][layer_name ] = (
347+ k_task_id ,
348+ v_task_id ,
349+ )
350+
351+ if request .load_async :
352+ for _ , (k_task , v_task ) in self .layerwise_load_tasks [
353+ request .request_id
354+ ].items ():
355+ if request .request_id not in self ._need_load_reqs :
356+ self ._need_load_reqs [request .request_id ] = []
357+ self ._need_load_reqs [request .request_id ].append (k_task )
334358 if not self .is_mla :
335- task = self .connector .load (
336- fetch_block_hashes ,
337- offsets [blocks_len :],
338- tensors [blocks_len :],
339- )
340- assert self .connector .wait (task ) == 0
341- else :
342- k_task_id = self .connector .load (
343- fetch_block_hashes , offsets [:blocks_len ], tensors [:blocks_len ]
344- )
345- v_task_id = None
359+ self ._need_load_reqs [request .request_id ].append (v_task )
360+ continue
361+
362+ if not self .use_layerwise :
363+ for _ , (k_task , v_task ) in self .layerwise_load_tasks [
364+ request .request_id
365+ ].items ():
366+ assert self .connector .wait (k_task ) == 0
346367 if not self .is_mla :
347- v_task_id = self .connector .load (
348- fetch_block_hashes ,
349- offsets [blocks_len :],
350- tensors [blocks_len :],
351- )
352- if request .request_id not in self .layerwise_load_tasks :
353- self .layerwise_load_tasks [request .request_id ] = {}
354- self .layerwise_load_tasks [request .request_id ][layer_name ] = (
355- k_task_id ,
356- v_task_id ,
357- )
368+ assert self .connector .wait (v_task ) == 0
358369
359370 def wait_for_layer_load (self , layer_name : str ) -> None :
360371 """
@@ -415,7 +426,7 @@ def save_kv_layer(
415426 assert attn_metadata is not None , "The attn_metadata should not be None."
416427
417428 for request in metadata .requests :
418- if request .save_paras is None :
429+ if request .save_paras is None or request . load_async :
419430 continue
420431
421432 save_param = request .save_paras
@@ -535,6 +546,30 @@ def wait_for_tasks():
535546 self .dump_tasks .clear ()
536547 return success_dumped_blocks if success_dumped_blocks else None
537548
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Poll outstanding asynchronous load tasks and report completed requests.

    Every task recorded in ``self._need_load_reqs`` is checked through
    ``self.connector.check``. A result of ``0`` is treated as finished (the
    task is then waited on to collect its final status), ``-1`` is treated as
    still in flight, and any other value is treated as an unknown task.
    Requests with no remaining in-flight task are removed from
    ``self._need_load_reqs`` and reported as done; a request whose task list
    is already empty is reported as done as well.

    Args:
        finished_req_ids: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        A ``(sending, recving)`` pair. ``sending`` is always ``None`` here;
        ``recving`` is the set of request ids whose load tasks have all
        completed (possibly empty).

    Raises:
        RuntimeError: If a task reported as finished completes with a
            non-zero status. Recomputing failed loads is not supported yet.
        ValueError: If ``check`` returns a status other than ``0`` or ``-1``.
    """
    done_recving: set[str] = set()

    # Iterate over a snapshot so entries can be rewritten or dropped safely.
    for req_id, tasks in list(self._need_load_reqs.items()):
        unfinished_tasks = []
        for task in tasks:
            ret = self.connector.check(task)
            if ret == 0:
                # Call wait() outside of `assert`: assertions are stripped
                # under `python -O`, which would skip the call entirely and
                # hide a failed load.
                if self.connector.wait(task) != 0:
                    raise RuntimeError(
                        f"Async load task {task.get_id()} of request "
                        f"{req_id} failed"
                    )
                continue
            if ret != -1:
                raise ValueError(f"Task {task.get_id()} Not Found")
            unfinished_tasks.append(task)

        if unfinished_tasks:
            self._need_load_reqs[req_id] = unfinished_tasks
        else:
            # Nothing left in flight: report the request and stop tracking it.
            done_recving.add(req_id)
            del self._need_load_reqs[req_id]

    return None, done_recving
572+
538573 # ==============================
539574 # Scheduler-side methods
540575 # ==============================
@@ -588,6 +623,12 @@ def md5(input) -> int:
588623 can_load = False ,
589624 )
590625
626+ need_load_tokens = max (num_external_computed_tokens - num_computed_tokens , 0 )
627+ if hasattr (self , "kv_role" ) and self .kv_role == "kv_consumer" :
628+ if need_load_tokens > 0 :
629+ self ._need_load_reqs [request .request_id ] = []
630+ return need_load_tokens , True
631+
591632 num_max_cached_tokens = max (num_external_computed_tokens , num_computed_tokens )
592633 num_blocks_need_save = (
593634 len (request .all_token_ids ) - num_max_cached_tokens
@@ -608,7 +649,7 @@ def md5(input) -> int:
608649 f"num_computed_tokens = { num_computed_tokens } .\n "
609650 )
610651
611- return max ( num_external_computed_tokens - num_computed_tokens , 0 ) , False
652+ return need_load_tokens , False
612653
613654 def update_state_after_alloc (
614655 self , request : "Request" , blocks : "KVCacheBlocks" , num_external_tokens : int
@@ -623,6 +664,12 @@ def update_state_after_alloc(
623664 if num_external_tokens > 0 :
624665 self .load_paras [request .request_id ].can_load = True
625666
667+ if request .request_id in self ._need_load_reqs :
668+ local_block_ids = (
669+ blocks .get_unhashed_block_ids () if num_external_tokens > 0 else []
670+ )
671+ self ._need_load_reqs [request .request_id ] = local_block_ids
672+
626673 def build_connector_meta (
627674 self , scheduler_output : SchedulerOutput
628675 ) -> KVConnectorMetadata :
@@ -636,6 +683,16 @@ def build_connector_meta(
636683 scheduler_output (SchedulerOutput): the scheduler output object.
637684 """
638685 meta = UCConnectorV1Metadata ()
686+
687+ for req_id , block_ids in self ._need_load_reqs .items ():
688+ meta .add_request (
689+ req_id ,
690+ vllm_block_ids = block_ids ,
691+ load_paras = self .load_paras [req_id ],
692+ load_async = True ,
693+ )
694+ self ._need_load_reqs .clear ()
695+
639696 for new_req in scheduler_output .scheduled_new_reqs :
640697 # Load kv is only supported for new reqs
641698 new_scheduled_blocks = (
0 commit comments