2525import random
2626import secrets
2727import unittest
28- from typing import List
28+ from typing import List , Union
2929from unittest .mock import MagicMock , Mock , patch
3030
3131import torch
3434from vllm .v1 .request import Request
3535
3636from ucm .integration .vllm .uc_connector import (
37- LoadPara ,
37+ BlockOperation ,
3838 ReqMeta ,
39- SavePara ,
39+ RequestBlockInfo ,
4040 UCConnectorV1Metadata ,
4141 UnifiedCacheConnectorV1 ,
4242)
@@ -100,10 +100,8 @@ def init_uc(
100100 ucconnector .rank = 1
101101 ucconnector .is_mla = False
102102 ucconnector .connector = mock_connector
103- ucconnector .load_paras : dict [str , LoadPara ] = {}
104- ucconnector .save_paras : dict [str , SavePara ] = {}
103+ ucconnector .request_block_infos : dict [str , RequestBlockInfo ] = {}
105104 ucconnector .dump_tasks : dict [str , dict [str , List [Task ]]] = {}
106- ucconnector .load_tasks : dict [str , tuple [Task , Task ]] = {}
107105 ucconnector .total_tp_size = 2
108106 ucconnector ._connector_metadata = metadata
109107 ucconnector .layerwise_load_tasks : dict [
@@ -114,14 +112,47 @@ def init_uc(
114112 ucconnector ._load_req_to_blocks : dict [str , set [int ]] = {}
115113 return ucconnector
116114
117- def test_get_num_new_matched_tokens_hit (self ):
115+ def test_get_num_new_matched_tokens_hit_all_on_storage (self ):
118116 mock_connector = Mock (spec = UcmKVStoreBase )
119117
120118 def mock_lookup (tokens : List [int ]) -> List [bool ]:
121119 return [True ] * self .block_number
122120
121+ mock_connector .lookup .side_effect = mock_lookup
122+ ucconnector = self .init_uc (mock_connector )
123+
124+ random .seed (20250704 )
125+ request1 = make_request (
126+ request_id = 1 ,
127+ prompt_token_ids = random .sample (
128+ range (0 , 10000 ), self .block_number * self .block_size
129+ ),
130+ mm_positions = None ,
131+ mm_hashes = None ,
132+ )
133+
134+ # all blocks are already dumped to SSD, so external_tokens equals the total number of tokens minus self.block_size
135+ all_tokens_len = len (request1 .all_token_ids )
136+ external_tokens , _ = ucconnector .get_num_new_matched_tokens (request1 , 0 )
137+ self .assertEqual (external_tokens , all_tokens_len - self .block_size )
138+ self .assertEqual (
139+ ucconnector .request_block_infos [request1 .request_id ].block_operations ,
140+ [
141+ BlockOperation .LOAD ,
142+ BlockOperation .LOAD ,
143+ BlockOperation .LOAD ,
144+ BlockOperation .NONE ,
145+ ],
146+ )
147+
148+ def test_get_num_new_matched_tokens_partial_hit (self ):
149+ mock_connector = Mock (spec = UcmKVStoreBase )
150+
151+ def mock_lookup (tokens : List [int ]) -> List [bool ]:
152+ return [True , False , True , False ]
153+
123154 def mock_create (tokens : List [str ]) -> List [int ]:
124- return [1 ] * self . block_number
155+ return [0 , 1 , 0 ]
125156
126157 mock_connector .lookup .side_effect = mock_lookup
127158 mock_connector .create .side_effect = mock_create
@@ -137,10 +168,60 @@ def mock_create(tokens: List[str]) -> List[int]:
137168 mm_hashes = None ,
138169 )
139170
140- # all block dumped in ssd, external_tokens equals to full tokens num
171+ # lookup hits blocks 0 and 2 only; just the leading hit is loadable, so external_tokens equals self.block_size
141172 all_tokens_len = len (request1 .all_token_ids )
142173 external_tokens , _ = ucconnector .get_num_new_matched_tokens (request1 , 0 )
143- self .assertEqual (external_tokens , all_tokens_len - self .block_size )
174+ self .assertEqual (external_tokens , self .block_size )
175+ self .assertEqual (
176+ ucconnector .request_block_infos [request1 .request_id ].block_operations ,
177+ [
178+ BlockOperation .LOAD ,
179+ BlockOperation .DUMP ,
180+ BlockOperation .NONE ,
181+ BlockOperation .DUMP ,
182+ ],
183+ )
184+
185+ def test_get_num_new_matched_tokens_partial_hit_with_preftxcache (self ):
186+ mock_connector = Mock (spec = UcmKVStoreBase )
187+
188+ def mock_lookup (tokens : List [int ]) -> List [bool ]:
189+ return [False , True , False ]
190+
191+ def mock_create (tokens : List [str ]) -> List [int ]:
192+ return [0 , 1 , 0 ]
193+
194+ mock_connector .lookup .side_effect = mock_lookup
195+ mock_connector .create .side_effect = mock_create
196+ ucconnector = self .init_uc (mock_connector )
197+
198+ random .seed (20250704 )
199+ request1 = make_request (
200+ request_id = 1 ,
201+ prompt_token_ids = random .sample (
202+ range (0 , 10000 ), self .block_number * self .block_size
203+ ),
204+ mm_positions = None ,
205+ mm_hashes = None ,
206+ )
207+
208+ # the first looked-up block misses in storage, so external_tokens equals 0
209+ external_tokens , _ = ucconnector .get_num_new_matched_tokens (
210+ request1 , self .block_size
211+ )
212+ self .assertEqual (external_tokens , 0 )
213+ self .assertEqual (
214+ ucconnector .request_block_infos [request1 .request_id ].start_position , 1
215+ )
216+ self .assertEqual (
217+ ucconnector .request_block_infos [request1 .request_id ].block_operations ,
218+ [
219+ BlockOperation .NONE ,
220+ BlockOperation .DUMP ,
221+ BlockOperation .NONE ,
222+ BlockOperation .DUMP ,
223+ ],
224+ )
144225
145226 def test_get_num_new_matched_tokens_no_hit (self ):
146227 mock_connector = Mock (spec = UcmKVStoreBase )
@@ -149,7 +230,7 @@ def mock_lookup(tokens: List[int]) -> List[bool]:
149230 return [False ] * self .block_number
150231
151232 def mock_create (tokens : List [str ]) -> List [int ]:
152- return [1 ] * self .block_number
233+ return [0 ] * self .block_number
153234
154235 mock_connector .lookup .side_effect = mock_lookup
155236 mock_connector .create .side_effect = mock_create
@@ -192,15 +273,9 @@ def test_get_num_new_matched_tokens_invalid_para(self):
192273 def test_wait_for_save_not_layerwise_success (self ):
193274 req_meta1 = MagicMock (spec = ReqMeta )
194275 req_meta1 .request_id = "req1"
195- req_meta1 .save_paras = SavePara (
196- num_blocks_need_save = self .block_number ,
197- start_save_position = 0 ,
198- num_blocks_to_save = self .block_number ,
199- )
200- req_meta1 .save_paras .block_hashes = [
201- secrets .token_hex (8 ) for _ in range (self .block_number )
276+ req_meta1 .dump_blocks = [
277+ (secrets .token_hex (8 ), i ) for i in range (self .block_number )
202278 ]
203- req_meta1 .vllm_block_ids = list (range (self .block_number ))
204279
205280 metadata = UCConnectorV1Metadata ()
206281 metadata .requests = [req_meta1 ]
@@ -236,15 +311,10 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
236311 def test_start_load_kv_not_layerwise_success (self ):
237312 req_meta1 = MagicMock (spec = ReqMeta )
238313 req_meta1 .request_id = "req1"
239- req_meta1 .load_paras = LoadPara (
240- vllm_cached_tokens = 1 * self .block_size ,
241- storage_cached_tokens = self .block_number * self .block_size ,
242- can_load = True ,
243- )
244- req_meta1 .load_paras .block_hashes = [
245- secrets .token_hex (8 ) for _ in range (self .block_number )
314+ req_meta1 .load_blocks = [
315+ (secrets .token_hex (8 ), i ) for i in range (self .block_number )
246316 ]
247- req_meta1 .vllm_block_ids = list ( range ( self . block_number ))
317+ req_meta1 .load_async = False
248318
249319 metadata = UCConnectorV1Metadata ()
250320 metadata .requests = [req_meta1 ]
@@ -282,15 +352,9 @@ def test_start_load_kv_invalid_para(self):
282352 def test_start_load_kv_layerwise_success (self ):
283353 req_meta1 = MagicMock (spec = ReqMeta )
284354 req_meta1 .request_id = "req1"
285- req_meta1 .load_paras = LoadPara (
286- vllm_cached_tokens = 1 * self .block_size ,
287- storage_cached_tokens = self .block_number * self .block_size ,
288- can_load = True ,
289- )
290- req_meta1 .load_paras .block_hashes = [
291- secrets .token_hex (8 ) for _ in range (self .block_number )
355+ req_meta1 .load_blocks = [
356+ (secrets .token_hex (8 ), i ) for i in range (self .block_number )
292357 ]
293- req_meta1 .vllm_block_ids = list (range (self .block_number ))
294358
295359 metadata = UCConnectorV1Metadata ()
296360 metadata .requests = [req_meta1 ]
@@ -309,89 +373,6 @@ def mock_load(
309373 ucconnector .start_load_kv (forward_context )
310374 assert mock_connector .load .call_count == 2 * self .num_layers
311375
312- def test_generate_layerwise_load_tasks_success (self ):
313- # init implement
314- mock_connector = Mock (spec = UcmKVStoreBase )
315-
316- def mock_load (
317- block_ids : List [str ], offset : List [int ], dst_tensor : List [torch .Tensor ]
318- ) -> Task :
319- assert offset is not None and offset
320- assert dst_tensor is not None and dst_tensor
321- return Task ()
322-
323- mock_connector .load .side_effect = mock_load
324- ucconnector = self .init_uc (mock_connector )
325-
326- # provide generate_layerwise_load_tasks params
327- fetch_block_ids = list (range (self .block_number * 2 ))
328- fetch_block_hashes = [
329- secrets .token_hex (8 ) for _ in range (self .block_number * 2 )
330- ]
331- layer_to_tensor : dict [str , tuple [List [torch .Tensor ], List [int ]]] = {}
332- current_layer = 0
333- for layer_name , kv_layer in self .kv_caches .items ():
334- tensors , offsets = ucconnector .get_tensor_and_offset_layerwise (
335- fetch_block_ids , kv_layer , layer_name
336- )
337- layer_to_tensor [layer_name ] = (tensors , offsets )
338- current_layer += 1
339- # generate layerwise tasks
340- layerwise_load_task = ucconnector .generate_layerwise_load_tasks (
341- fetch_block_hashes , layer_to_tensor
342- )
343-
344- for i in range (self .num_layers ):
345- task = next (layerwise_load_task )
346- assert task is not None , f"layer { i } is None"
347- assert mock_connector .load .call_count == self .num_layers * 2
348-
349- def test_generate_layerwise_load_tasks_invalid_params (self ):
350- # init implement
351- mock_connector = Mock (spec = UcmKVStoreBase )
352-
353- def mock_load (
354- block_ids : List [str ], offset : List [int ], dst_tensor : List [torch .Tensor ]
355- ) -> Task :
356- assert offset is not None and offset
357- assert dst_tensor is not None and dst_tensor
358- return Task ()
359-
360- mock_connector .load .side_effect = mock_load
361- ucconnector = self .init_uc (mock_connector )
362-
363- # provide generate_layerwise_load_tasks params
364- fetch_block_ids = list (range (self .block_number * 2 ))
365- fetch_block_hashes = [
366- secrets .token_hex (8 ) for _ in range (self .block_number * 2 )
367- ]
368- layer_to_tensor : dict [str , tuple [List [torch .Tensor ], List [int ]]] = {}
369- for layer_name , kv_layer in self .kv_caches .items ():
370- tensors , offsets = ucconnector .get_tensor_and_offset_layerwise (
371- fetch_block_ids , kv_layer , layer_name
372- )
373- layer_to_tensor [layer_name ] = (tensors , offsets )
374- # generate layerwise tasks
375- layerwise_load_task = ucconnector .generate_layerwise_load_tasks (
376- [], layer_to_tensor
377- )
378- with self .assertRaises (AssertionError ) as context :
379- next (layerwise_load_task )
380- self .assertEqual (
381- str (context .exception ),
382- "The block hashes need to be fetched should not be None or empty." ,
383- )
384-
385- layerwise_load_task = ucconnector .generate_layerwise_load_tasks (
386- fetch_block_hashes , None
387- )
388- with self .assertRaises (AssertionError ) as context :
389- next (layerwise_load_task )
390- self .assertEqual (
391- str (context .exception ),
392- "The layers of tensor need to be fetched should not be None or empty." ,
393- )
394-
395376
396377if __name__ == "__main__" :
397378 unittest .main ()
0 commit comments