Skip to content

Commit 3d6c86e

Browse files
authored
[Feat] support load_async for load failure and adjust patches (#165)
Signed-off-by: flesher0813 <1208954694@qq.com>
1 parent 81f84e2 commit 3d6c86e

File tree

4 files changed

+586
-89
lines changed

4 files changed

+586
-89
lines changed

test/test_uc_connector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def init_uc(
110110
str, dict[str, tuple[Task, Task]]
111111
] = {}
112112
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
113+
ucconnector._load_failed_reqs: set[str] = set()
114+
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
113115
return ucconnector
114116

115117
def test_get_num_new_matched_tokens_hit(self):

ucm/integration/vllm/uc_connector.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
141141
self.element_size = vllm_config.model_config.dtype.itemsize
142142
self.kv_role = vllm_config.kv_transfer_config.kv_role
143143
self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
144+
self._load_failed_reqs: set[str] = set()
145+
self._load_req_to_blocks: dict[str, set[int]] = {}
144146
if (
145147
self._vllm_config.kv_transfer_config is not None
146148
and "ucm_connector_name"
@@ -280,6 +282,16 @@ def load(tensor_list, offset_list) -> tuple[Task, Task]:
280282
# ==============================
281283
# Worker-side methods
282284
# ==============================
285+
def clear_connector_metadata(self) -> None:
286+
"""Clear the connector metadata.
287+
288+
This function should be called by the model runner every time
289+
after the model execution.
290+
"""
291+
self._load_failed_reqs.clear()
292+
self._load_req_to_blocks.clear()
293+
super().clear_connector_metadata()
294+
283295
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
284296
"""
285297
Start loading the KV cache from the connector to vLLM's paged
@@ -306,7 +318,6 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
306318
for request in metadata.requests:
307319
if request.load_paras is None or not request.load_paras.can_load:
308320
continue
309-
layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
310321
block_ids = request.vllm_block_ids
311322
# Block ids that need to be saved should start after the last vLLM cached block
312323
load_start_block_id = (
@@ -327,6 +338,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
327338
]
328339
assert len(fetch_block_ids) == len(fetch_block_hashes)
329340
blocks_len = len(fetch_block_ids)
341+
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
342+
fetch_block_ids
343+
)
330344
for layer_name, kv_layer in self.kv_caches.items():
331345
tensors, offsets = self.get_tensor_and_offset_layerwise(
332346
fetch_block_ids, kv_layer, layer_name
@@ -363,9 +377,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
363377
for _, (k_task, v_task) in self.layerwise_load_tasks[
364378
request.request_id
365379
].items():
366-
assert self.connector.wait(k_task) == 0
367-
if not self.is_mla:
368-
assert self.connector.wait(v_task) == 0
380+
if self.connector.wait(k_task) != 0:
381+
self._load_failed_reqs.add(request.request_id)
382+
break
383+
if v_task and self.connector.wait(v_task) != 0:
384+
self._load_failed_reqs.add(request.request_id)
385+
break
369386

370387
def wait_for_layer_load(self, layer_name: str) -> None:
371388
"""
@@ -387,10 +404,16 @@ def wait_for_layer_load(self, layer_name: str) -> None:
387404
self.current_layer < self.num_layers
388405
), "The current layer should be less than total layers!"
389406
for request_id, layer_to_task in self.layerwise_load_tasks.items():
407+
if request_id in self._load_failed_reqs:
408+
continue
390409
k_task, v_task = layer_to_task[layer_name]
391-
assert self.connector.wait(k_task) == 0
410+
if self.connector.wait(k_task) != 0:
411+
self._load_failed_reqs.add(request_id)
412+
continue
392413
if not self.is_mla:
393-
assert self.connector.wait(v_task) == 0
414+
if self.connector.wait(v_task) != 0:
415+
self._load_failed_reqs.add(request_id)
416+
continue
394417
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
395418

396419
def save_kv_layer(
@@ -550,16 +573,18 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
550573
"""Get the finished receiving and sending requests."""
551574
done_recving: set[str] = set()
552575
for req_id, tasks in self._need_load_reqs.items():
576+
if req_id in self._load_failed_reqs:
577+
continue
553578
unfinished_tasks = []
554579
for task in tasks:
555580
ret = self.connector.check(task)
556-
if ret == 0:
557-
# Remove this assertion after support recompute load failed reqs
558-
assert self.connector.wait(task) == 0
581+
if ret == -1:
582+
unfinished_tasks.append(task)
583+
continue
584+
elif ret == 0 and self.connector.wait(task) == 0:
559585
continue
560-
if ret != -1:
561-
raise ValueError(f"Task {task.get_id()} Not Found")
562-
unfinished_tasks.append(task)
586+
self._load_failed_reqs.add(req_id)
587+
break
563588
if not unfinished_tasks:
564589
done_recving.add(req_id)
565590
self._need_load_reqs[req_id] = unfinished_tasks
@@ -624,10 +649,19 @@ def md5(input) -> int:
624649
)
625650

626651
need_load_tokens = max(num_external_computed_tokens - num_computed_tokens, 0)
652+
# Load asynchronously when the Decode instance needs to load.
627653
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
654+
# Only trigger 1 asynchronous KV transfer per request.
655+
if (
656+
request.kv_transfer_params
657+
and request.kv_transfer_params["load_async"] == False
658+
):
659+
return 0, False
660+
request.kv_transfer_params = request.kv_transfer_params or {}
661+
request.kv_transfer_params["load_async"] = False
628662
if need_load_tokens > 0:
629663
self._need_load_reqs[request.request_id] = []
630-
return need_load_tokens, True
664+
return need_load_tokens, True
631665

632666
num_max_cached_tokens = max(num_external_computed_tokens, num_computed_tokens)
633667
num_blocks_need_save = (
@@ -778,6 +812,13 @@ def request_finished(
778812
self.connector.commit(cancel_blocks, False)
779813
return False, None
780814

815+
def get_block_ids_with_load_errors(self) -> set[int]:
816+
invalid_block_ids: set[int] = set()
817+
for req_id in self._load_failed_reqs:
818+
if req_id in self._load_req_to_blocks:
819+
invalid_block_ids.update(self._load_req_to_blocks[req_id])
820+
return invalid_block_ids
821+
781822
@staticmethod
782823
def _extract_layer_index(layer_name: str) -> Optional[int]:
783824
"""

0 commit comments

Comments
 (0)