Skip to content

Commit 62ed31f

Browse files
authored
[Feature]refactor ucconnector (#167)
* refactor ucconnector * fic comment
1 parent 3d6c86e commit 62ed31f

File tree

3 files changed

+302
-372
lines changed

3 files changed

+302
-372
lines changed

test/test_uc_connector.py

Lines changed: 99 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import random
2626
import secrets
2727
import unittest
28-
from typing import List
28+
from typing import List, Union
2929
from unittest.mock import MagicMock, Mock, patch
3030

3131
import torch
@@ -34,9 +34,9 @@
3434
from vllm.v1.request import Request
3535

3636
from ucm.integration.vllm.uc_connector import (
37-
LoadPara,
37+
BlockOperation,
3838
ReqMeta,
39-
SavePara,
39+
RequestBlockInfo,
4040
UCConnectorV1Metadata,
4141
UnifiedCacheConnectorV1,
4242
)
@@ -100,10 +100,8 @@ def init_uc(
100100
ucconnector.rank = 1
101101
ucconnector.is_mla = False
102102
ucconnector.connector = mock_connector
103-
ucconnector.load_paras: dict[str, LoadPara] = {}
104-
ucconnector.save_paras: dict[str, SavePara] = {}
103+
ucconnector.request_block_infos: dict[str, RequestBlockInfo] = {}
105104
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
106-
ucconnector.load_tasks: dict[str, tuple[Task, Task]] = {}
107105
ucconnector.total_tp_size = 2
108106
ucconnector._connector_metadata = metadata
109107
ucconnector.layerwise_load_tasks: dict[
@@ -114,14 +112,47 @@ def init_uc(
114112
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
115113
return ucconnector
116114

117-
def test_get_num_new_matched_tokens_hit(self):
115+
def test_get_num_new_matched_tokens_hit_all_on_storage(self):
118116
mock_connector = Mock(spec=UcmKVStoreBase)
119117

120118
def mock_lookup(tokens: List[int]) -> List[bool]:
121119
return [True] * self.block_number
122120

121+
mock_connector.lookup.side_effect = mock_lookup
122+
ucconnector = self.init_uc(mock_connector)
123+
124+
random.seed(20250704)
125+
request1 = make_request(
126+
request_id=1,
127+
prompt_token_ids=random.sample(
128+
range(0, 10000), self.block_number * self.block_size
129+
),
130+
mm_positions=None,
131+
mm_hashes=None,
132+
)
133+
134+
# all block dumped in ssd, external_tokens equals to full tokens num - self.block_size
135+
all_tokens_len = len(request1.all_token_ids)
136+
external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0)
137+
self.assertEqual(external_tokens, all_tokens_len - self.block_size)
138+
self.assertEqual(
139+
ucconnector.request_block_infos[request1.request_id].block_operations,
140+
[
141+
BlockOperation.LOAD,
142+
BlockOperation.LOAD,
143+
BlockOperation.LOAD,
144+
BlockOperation.NONE,
145+
],
146+
)
147+
148+
def test_get_num_new_matched_tokens_partial_hit(self):
149+
mock_connector = Mock(spec=UcmKVStoreBase)
150+
151+
def mock_lookup(tokens: List[int]) -> List[bool]:
152+
return [True, False, True, False]
153+
123154
def mock_create(tokens: List[str]) -> List[int]:
124-
return [1] * self.block_number
155+
return [0, 1, 0]
125156

126157
mock_connector.lookup.side_effect = mock_lookup
127158
mock_connector.create.side_effect = mock_create
@@ -137,10 +168,60 @@ def mock_create(tokens: List[str]) -> List[int]:
137168
mm_hashes=None,
138169
)
139170

140-
# all block dumped in ssd, external_tokens equals to full tokens num
171+
# all block dumped in ssd, external_tokens equals to full tokens num - self.block_size
141172
all_tokens_len = len(request1.all_token_ids)
142173
external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0)
143-
self.assertEqual(external_tokens, all_tokens_len - self.block_size)
174+
self.assertEqual(external_tokens, self.block_size)
175+
self.assertEqual(
176+
ucconnector.request_block_infos[request1.request_id].block_operations,
177+
[
178+
BlockOperation.LOAD,
179+
BlockOperation.DUMP,
180+
BlockOperation.NONE,
181+
BlockOperation.DUMP,
182+
],
183+
)
184+
185+
def test_get_num_new_matched_tokens_partial_hit_with_preftxcache(self):
186+
mock_connector = Mock(spec=UcmKVStoreBase)
187+
188+
def mock_lookup(tokens: List[int]) -> List[bool]:
189+
return [False, True, False]
190+
191+
def mock_create(tokens: List[str]) -> List[int]:
192+
return [0, 1, 0]
193+
194+
mock_connector.lookup.side_effect = mock_lookup
195+
mock_connector.create.side_effect = mock_create
196+
ucconnector = self.init_uc(mock_connector)
197+
198+
random.seed(20250704)
199+
request1 = make_request(
200+
request_id=1,
201+
prompt_token_ids=random.sample(
202+
range(0, 10000), self.block_number * self.block_size
203+
),
204+
mm_positions=None,
205+
mm_hashes=None,
206+
)
207+
208+
# no block dumped in ssd, external_tokens equals to 0
209+
external_tokens, _ = ucconnector.get_num_new_matched_tokens(
210+
request1, self.block_size
211+
)
212+
self.assertEqual(external_tokens, 0)
213+
self.assertEqual(
214+
ucconnector.request_block_infos[request1.request_id].start_position, 1
215+
)
216+
self.assertEqual(
217+
ucconnector.request_block_infos[request1.request_id].block_operations,
218+
[
219+
BlockOperation.NONE,
220+
BlockOperation.DUMP,
221+
BlockOperation.NONE,
222+
BlockOperation.DUMP,
223+
],
224+
)
144225

145226
def test_get_num_new_matched_tokens_no_hit(self):
146227
mock_connector = Mock(spec=UcmKVStoreBase)
@@ -149,7 +230,7 @@ def mock_lookup(tokens: List[int]) -> List[bool]:
149230
return [False] * self.block_number
150231

151232
def mock_create(tokens: List[str]) -> List[int]:
152-
return [1] * self.block_number
233+
return [0] * self.block_number
153234

154235
mock_connector.lookup.side_effect = mock_lookup
155236
mock_connector.create.side_effect = mock_create
@@ -192,15 +273,9 @@ def test_get_num_new_matched_tokens_invalid_para(self):
192273
def test_wait_for_save_not_layerwise_success(self):
193274
req_meta1 = MagicMock(spec=ReqMeta)
194275
req_meta1.request_id = "req1"
195-
req_meta1.save_paras = SavePara(
196-
num_blocks_need_save=self.block_number,
197-
start_save_position=0,
198-
num_blocks_to_save=self.block_number,
199-
)
200-
req_meta1.save_paras.block_hashes = [
201-
secrets.token_hex(8) for _ in range(self.block_number)
276+
req_meta1.dump_blocks = [
277+
(secrets.token_hex(8), i) for i in range(self.block_number)
202278
]
203-
req_meta1.vllm_block_ids = list(range(self.block_number))
204279

205280
metadata = UCConnectorV1Metadata()
206281
metadata.requests = [req_meta1]
@@ -236,15 +311,10 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
236311
def test_start_load_kv_not_layerwise_success(self):
237312
req_meta1 = MagicMock(spec=ReqMeta)
238313
req_meta1.request_id = "req1"
239-
req_meta1.load_paras = LoadPara(
240-
vllm_cached_tokens=1 * self.block_size,
241-
storage_cached_tokens=self.block_number * self.block_size,
242-
can_load=True,
243-
)
244-
req_meta1.load_paras.block_hashes = [
245-
secrets.token_hex(8) for _ in range(self.block_number)
314+
req_meta1.load_blocks = [
315+
(secrets.token_hex(8), i) for i in range(self.block_number)
246316
]
247-
req_meta1.vllm_block_ids = list(range(self.block_number))
317+
req_meta1.load_async = False
248318

249319
metadata = UCConnectorV1Metadata()
250320
metadata.requests = [req_meta1]
@@ -282,15 +352,9 @@ def test_start_load_kv_invalid_para(self):
282352
def test_start_load_kv_layerwise_success(self):
283353
req_meta1 = MagicMock(spec=ReqMeta)
284354
req_meta1.request_id = "req1"
285-
req_meta1.load_paras = LoadPara(
286-
vllm_cached_tokens=1 * self.block_size,
287-
storage_cached_tokens=self.block_number * self.block_size,
288-
can_load=True,
289-
)
290-
req_meta1.load_paras.block_hashes = [
291-
secrets.token_hex(8) for _ in range(self.block_number)
355+
req_meta1.load_blocks = [
356+
(secrets.token_hex(8), i) for i in range(self.block_number)
292357
]
293-
req_meta1.vllm_block_ids = list(range(self.block_number))
294358

295359
metadata = UCConnectorV1Metadata()
296360
metadata.requests = [req_meta1]
@@ -309,89 +373,6 @@ def mock_load(
309373
ucconnector.start_load_kv(forward_context)
310374
assert mock_connector.load.call_count == 2 * self.num_layers
311375

312-
def test_generate_layerwise_load_tasks_success(self):
313-
# init implement
314-
mock_connector = Mock(spec=UcmKVStoreBase)
315-
316-
def mock_load(
317-
block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
318-
) -> Task:
319-
assert offset is not None and offset
320-
assert dst_tensor is not None and dst_tensor
321-
return Task()
322-
323-
mock_connector.load.side_effect = mock_load
324-
ucconnector = self.init_uc(mock_connector)
325-
326-
# provide generate_layerwise_load_tasks params
327-
fetch_block_ids = list(range(self.block_number * 2))
328-
fetch_block_hashes = [
329-
secrets.token_hex(8) for _ in range(self.block_number * 2)
330-
]
331-
layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
332-
current_layer = 0
333-
for layer_name, kv_layer in self.kv_caches.items():
334-
tensors, offsets = ucconnector.get_tensor_and_offset_layerwise(
335-
fetch_block_ids, kv_layer, layer_name
336-
)
337-
layer_to_tensor[layer_name] = (tensors, offsets)
338-
current_layer += 1
339-
# generate layerwise tasks
340-
layerwise_load_task = ucconnector.generate_layerwise_load_tasks(
341-
fetch_block_hashes, layer_to_tensor
342-
)
343-
344-
for i in range(self.num_layers):
345-
task = next(layerwise_load_task)
346-
assert task is not None, f"layer {i} is None"
347-
assert mock_connector.load.call_count == self.num_layers * 2
348-
349-
def test_generate_layerwise_load_tasks_invalid_params(self):
350-
# init implement
351-
mock_connector = Mock(spec=UcmKVStoreBase)
352-
353-
def mock_load(
354-
block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
355-
) -> Task:
356-
assert offset is not None and offset
357-
assert dst_tensor is not None and dst_tensor
358-
return Task()
359-
360-
mock_connector.load.side_effect = mock_load
361-
ucconnector = self.init_uc(mock_connector)
362-
363-
# provide generate_layerwise_load_tasks params
364-
fetch_block_ids = list(range(self.block_number * 2))
365-
fetch_block_hashes = [
366-
secrets.token_hex(8) for _ in range(self.block_number * 2)
367-
]
368-
layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
369-
for layer_name, kv_layer in self.kv_caches.items():
370-
tensors, offsets = ucconnector.get_tensor_and_offset_layerwise(
371-
fetch_block_ids, kv_layer, layer_name
372-
)
373-
layer_to_tensor[layer_name] = (tensors, offsets)
374-
# generate layerwise tasks
375-
layerwise_load_task = ucconnector.generate_layerwise_load_tasks(
376-
[], layer_to_tensor
377-
)
378-
with self.assertRaises(AssertionError) as context:
379-
next(layerwise_load_task)
380-
self.assertEqual(
381-
str(context.exception),
382-
"The block hashes need to be fetched should not be None or empty.",
383-
)
384-
385-
layerwise_load_task = ucconnector.generate_layerwise_load_tasks(
386-
fetch_block_hashes, None
387-
)
388-
with self.assertRaises(AssertionError) as context:
389-
next(layerwise_load_task)
390-
self.assertEqual(
391-
str(context.exception),
392-
"The layers of tensor need to be fetched should not be None or empty.",
393-
)
394-
395376

396377
if __name__ == "__main__":
397378
unittest.main()

0 commit comments

Comments
 (0)