Skip to content

Commit 81f84e2

Browse files
flesher0813y00945504
andauthored
[Feat] Support load async (#166)
* [Feat] support load_async * [Feat] Support asynchronous loading of KV cache Signed-off-by: y00945504 <yuhui87@huawei.com> * [Fix bug] adapt unittest and continue to load next req when finish dealing with load_async req Signed-off-by: flesher0813 <1208954694@qq.com> * [Fix bug] func check return value not match Signed-off-by: flesher0813 <1208954694@qq.com> --------- Signed-off-by: y00945504 <yuhui87@huawei.com> Signed-off-by: flesher0813 <1208954694@qq.com> Co-authored-by: y00945504 <yuhui87@huawei.com>
1 parent bf081fb commit 81f84e2

File tree

6 files changed

+143
-29
lines changed

6 files changed

+143
-29
lines changed

test/test_uc_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def init_uc(
109109
ucconnector.layerwise_load_tasks: dict[
110110
str, dict[str, tuple[Task, Task]]
111111
] = {}
112+
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
112113
return ucconnector
113114

114115
def test_get_num_new_matched_tokens_hit(self):

ucm/integration/vllm/uc_connector.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import hashlib
2727
import pickle
2828
from dataclasses import dataclass, field
29-
from typing import TYPE_CHECKING, Any, Generator, List, Optional
29+
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
3030

3131
import torch
3232
from vllm.config import VllmConfig
@@ -88,6 +88,8 @@ class ReqMeta:
8888
load_paras: Optional[LoadPara] = None
8989
# Save information
9090
save_paras: Optional[SavePara] = None
91+
# Mark request which need load async
92+
load_async: bool = False
9193

9294

9395
@dataclass
@@ -103,13 +105,15 @@ def add_request(
103105
vllm_block_ids: list[int],
104106
load_paras: Optional[LoadPara] = None,
105107
save_paras: Optional[SavePara] = None,
108+
load_async: bool = False,
106109
) -> None:
107110
self.requests.append(
108111
ReqMeta(
109112
request_id=request_id,
110113
vllm_block_ids=vllm_block_ids,
111114
load_paras=load_paras,
112115
save_paras=save_paras,
116+
load_async=load_async,
113117
)
114118
)
115119

@@ -136,6 +140,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
136140
)
137141
self.element_size = vllm_config.model_config.dtype.itemsize
138142
self.kv_role = vllm_config.kv_transfer_config.kv_role
143+
self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
139144
if (
140145
self._vllm_config.kv_transfer_config is not None
141146
and "ucm_connector_name"
@@ -326,35 +331,41 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
326331
tensors, offsets = self.get_tensor_and_offset_layerwise(
327332
fetch_block_ids, kv_layer, layer_name
328333
)
329-
if not self.use_layerwise:
330-
task = self.connector.load(
331-
fetch_block_hashes, offsets[:blocks_len], tensors[:blocks_len]
334+
k_task_id = self.connector.load(
335+
fetch_block_hashes, offsets[:blocks_len], tensors[:blocks_len]
336+
)
337+
v_task_id = None
338+
if not self.is_mla:
339+
v_task_id = self.connector.load(
340+
fetch_block_hashes,
341+
offsets[blocks_len:],
342+
tensors[blocks_len:],
332343
)
333-
assert self.connector.wait(task) == 0
344+
if request.request_id not in self.layerwise_load_tasks:
345+
self.layerwise_load_tasks[request.request_id] = {}
346+
self.layerwise_load_tasks[request.request_id][layer_name] = (
347+
k_task_id,
348+
v_task_id,
349+
)
350+
351+
if request.load_async:
352+
for _, (k_task, v_task) in self.layerwise_load_tasks[
353+
request.request_id
354+
].items():
355+
if request.request_id not in self._need_load_reqs:
356+
self._need_load_reqs[request.request_id] = []
357+
self._need_load_reqs[request.request_id].append(k_task)
334358
if not self.is_mla:
335-
task = self.connector.load(
336-
fetch_block_hashes,
337-
offsets[blocks_len:],
338-
tensors[blocks_len:],
339-
)
340-
assert self.connector.wait(task) == 0
341-
else:
342-
k_task_id = self.connector.load(
343-
fetch_block_hashes, offsets[:blocks_len], tensors[:blocks_len]
344-
)
345-
v_task_id = None
359+
self._need_load_reqs[request.request_id].append(v_task)
360+
continue
361+
362+
if not self.use_layerwise:
363+
for _, (k_task, v_task) in self.layerwise_load_tasks[
364+
request.request_id
365+
].items():
366+
assert self.connector.wait(k_task) == 0
346367
if not self.is_mla:
347-
v_task_id = self.connector.load(
348-
fetch_block_hashes,
349-
offsets[blocks_len:],
350-
tensors[blocks_len:],
351-
)
352-
if request.request_id not in self.layerwise_load_tasks:
353-
self.layerwise_load_tasks[request.request_id] = {}
354-
self.layerwise_load_tasks[request.request_id][layer_name] = (
355-
k_task_id,
356-
v_task_id,
357-
)
368+
assert self.connector.wait(v_task) == 0
358369

359370
def wait_for_layer_load(self, layer_name: str) -> None:
360371
"""
@@ -415,7 +426,7 @@ def save_kv_layer(
415426
assert attn_metadata is not None, "The attn_metadata should not be None."
416427

417428
for request in metadata.requests:
418-
if request.save_paras is None:
429+
if request.save_paras is None or request.load_async:
419430
continue
420431

421432
save_param = request.save_paras
@@ -535,6 +546,30 @@ def wait_for_tasks():
535546
self.dump_tasks.clear()
536547
return success_dumped_blocks if success_dumped_blocks else None
537548

549+
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
550+
"""Get the finished recving and sending requests."""
551+
done_recving: set[str] = set()
552+
for req_id, tasks in self._need_load_reqs.items():
553+
unfinished_tasks = []
554+
for task in tasks:
555+
ret = self.connector.check(task)
556+
if ret == 0:
557+
# Remove this assertion after support recompute load failed reqs
558+
assert self.connector.wait(task) == 0
559+
continue
560+
if ret != -1:
561+
raise ValueError(f"Task {task.get_id()} Not Found")
562+
unfinished_tasks.append(task)
563+
if not unfinished_tasks:
564+
done_recving.add(req_id)
565+
self._need_load_reqs[req_id] = unfinished_tasks
566+
567+
# remove the finished requests
568+
for req_id in list(done_recving):
569+
self._need_load_reqs.pop(req_id, None)
570+
571+
return None, done_recving
572+
538573
# ==============================
539574
# Scheduler-side methods
540575
# ==============================
@@ -588,6 +623,12 @@ def md5(input) -> int:
588623
can_load=False,
589624
)
590625

626+
need_load_tokens = max(num_external_computed_tokens - num_computed_tokens, 0)
627+
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
628+
if need_load_tokens > 0:
629+
self._need_load_reqs[request.request_id] = []
630+
return need_load_tokens, True
631+
591632
num_max_cached_tokens = max(num_external_computed_tokens, num_computed_tokens)
592633
num_blocks_need_save = (
593634
len(request.all_token_ids) - num_max_cached_tokens
@@ -608,7 +649,7 @@ def md5(input) -> int:
608649
f"num_computed_tokens = {num_computed_tokens}.\n"
609650
)
610651

611-
return max(num_external_computed_tokens - num_computed_tokens, 0), False
652+
return need_load_tokens, False
612653

613654
def update_state_after_alloc(
614655
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
@@ -623,6 +664,12 @@ def update_state_after_alloc(
623664
if num_external_tokens > 0:
624665
self.load_paras[request.request_id].can_load = True
625666

667+
if request.request_id in self._need_load_reqs:
668+
local_block_ids = (
669+
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
670+
)
671+
self._need_load_reqs[request.request_id] = local_block_ids
672+
626673
def build_connector_meta(
627674
self, scheduler_output: SchedulerOutput
628675
) -> KVConnectorMetadata:
@@ -636,6 +683,16 @@ def build_connector_meta(
636683
scheduler_output (SchedulerOutput): the scheduler output object.
637684
"""
638685
meta = UCConnectorV1Metadata()
686+
687+
for req_id, block_ids in self._need_load_reqs.items():
688+
meta.add_request(
689+
req_id,
690+
vllm_block_ids=block_ids,
691+
load_paras=self.load_paras[req_id],
692+
load_async=True,
693+
)
694+
self._need_load_reqs.clear()
695+
639696
for new_req in scheduler_output.scheduled_new_reqs:
640697
# Load kv is only supported for new reqs
641698
new_scheduled_blocks = (

ucm/store/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,16 @@ def commit(self, block_ids: List[str], is_success: bool = True) -> None:
137137
is_success(bool): if False, we need release block
138138
"""
139139
pass
140+
141+
@abstractmethod
142+
def check(self, task: Task) -> int:
143+
"""
144+
check if kv transfer task finished.
145+
146+
Args:
147+
task (Task): transfer engine task.
148+
Returns:
149+
0 - finished
150+
others - in process.
151+
"""
152+
pass

ucm/store/ucm_dram.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,15 @@ def commit(self, block_ids: List[str], is_success: bool = True) -> None:
187187
"""
188188
if is_success:
189189
self.cached_blocks.update(block_ids)
190+
191+
def check(self, task: Task) -> int:
192+
"""
193+
check if kv transfer task finished.
194+
195+
Args:
196+
task (Task): transfer engine task.
197+
Returns:
198+
0 - finished
199+
others - in process.
200+
"""
201+
pass

ucm/store/ucm_mooncake.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,15 @@ def shutdown(self):
331331
self.loop.close()
332332

333333
self.store.close()
334+
335+
def check(self, task: Task) -> int:
336+
"""
337+
check if kv transfer task finished.
338+
339+
Args:
340+
task (Task): transfer engine task.
341+
Returns:
342+
0 - finished
343+
others - in process.
344+
"""
345+
pass

ucm/store/ucm_nfs_store.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,22 @@ def commit(self, block_ids: List[str], is_success: bool = True) -> None:
185185
logger.warning(f"commit {block_ids} to {is_success}")
186186
ucmnfsstore.CommitBatch(block_ids, is_success)
187187
logger.debug("Succeed in committing kv cache.")
188+
189+
def check(self, task: Task) -> int:
190+
"""
191+
check if kv transfer task finished.
192+
193+
Args:
194+
task (Task): transfer engine task.
195+
Returns:
196+
0 - finished
197+
-1 - in process
198+
others - errors.
199+
"""
200+
ret, finish = ucmnfsstore.Check(task.get_id())
201+
if ret == 0 and finish == True:
202+
return 0
203+
elif ret == 0:
204+
return -1
205+
logger.error(f"check {task.get_id()} return {ret}")
206+
return ret

0 commit comments

Comments
 (0)