Skip to content

Commit 5c191a2

Browse files
authored
[perf] prepare offset in advance (#188)
* prepare offset in advance * fix CI problem * fix MLA offset error * scheduler does not need data offset
1 parent 62685d6 commit 5c191a2

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

test/test_uc_connector.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,19 @@ def setUp(self):
8181
self.block_size = 8
8282
self.num_layers = 48
8383
self.total_blocks_num = 40
84+
self.total_tp_size = 2
8485
self.kv_caches = {}
86+
self.k_data_offsets = {}
8587
for i in range(self.num_layers):
8688
layer_name = f"model.layers.{i}.self_attn.attn"
8789
kv_tensor = torch.rand(
8890
(2, self.total_blocks_num, self.block_size, 4, 8), dtype=torch.bfloat16
8991
)
9092
self.kv_caches[layer_name] = kv_tensor
93+
for layer_id in range(self.num_layers):
94+
self.k_data_offsets[layer_id] = {}
95+
for i in range(self.total_tp_size):
96+
self.k_data_offsets[layer_id][i] = 0
9197

9298
def init_uc(
9399
self, mock_connector, metadata=Mock(), use_layerwise=True
@@ -102,14 +108,16 @@ def init_uc(
102108
ucconnector.connector = mock_connector
103109
ucconnector.request_block_infos: dict[str, RequestBlockInfo] = {}
104110
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
105-
ucconnector.total_tp_size = 2
111+
ucconnector.total_tp_size = self.total_tp_size
106112
ucconnector._connector_metadata = metadata
107113
ucconnector.layerwise_load_tasks: dict[
108114
str, dict[str, tuple[Task, Task]]
109115
] = {}
110116
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
111117
ucconnector._load_failed_reqs: set[str] = set()
112118
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
119+
ucconnector.k_data_offsets = self.k_data_offsets
120+
ucconnector.min_block_size = 0
113121
return ucconnector
114122

115123
def test_get_num_new_matched_tokens_hit_all_on_storage(self):

ucm/integration/vllm/uc_connector.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
109109
self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
110110
self._load_failed_reqs: set[str] = set()
111111
self._load_req_to_blocks: dict[str, set[int]] = {}
112+
if role == KVConnectorRole.WORKER:
113+
self._initialize_dataoffset(vllm_config)
112114
if (
113115
self._vllm_config.kv_transfer_config is not None
114116
and "ucm_connector_name"
@@ -156,37 +158,34 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
156158
forward_context.virtual_engine
157159
]
158160

159-
def DataOffset(self, kv_layer, rank, layer_id, is_v):
160-
# Non-MLA scene: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
161-
# MLA scene: one layer shape is (num_blocks, block_size, head_size)
162-
# Element size
163-
elem_size = kv_layer[0].element_size()
164-
logger.debug(
165-
f"total_tp_size = {self.total_tp_size},\n" f"element size = {elem_size}."
161+
def _initialize_dataoffset(self, vllm_config: "VllmConfig"):
162+
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
163+
vllm_config.parallel_config
164+
)
165+
head_size = vllm_config.model_config.get_head_size()
166+
self.min_block_size = (
167+
self.block_size * num_kv_heads * head_size * self.element_size
166168
)
167-
# One block size
168-
k_min_data_block_size = (
169-
kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel()
170-
) * elem_size
171-
v_min_data_block_size = (
172-
kv_layer[1][0].numel() if not self.is_mla else 0
173-
) * elem_size
174-
# When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
175169
layer_size = (
176-
k_min_data_block_size + v_min_data_block_size
177-
) * self.total_tp_size
178-
if is_v:
179-
# Offset of v = Offset of k + k_min_data_block_size
180-
return int(
181-
self.DataOffset(kv_layer, rank, layer_id, False) + k_min_data_block_size
182-
)
183-
if self.is_mla:
184-
return int(layer_size * layer_id)
185-
else:
186-
# Offset of k = layer_size * layer_id + layer_size / tp_size * current rank
187-
return int(
188-
layer_size * layer_id + layer_size / self.total_tp_size * self.rank
189-
)
170+
self.min_block_size * 2 * self.total_tp_size
171+
if not self.is_mla
172+
else self.min_block_size
173+
)
174+
# layer_id -> rank -> k_offset
175+
self.k_data_offsets: dict[int, dict[int, int]] = {}
176+
177+
for layer_id in range(self.num_layers):
178+
self.k_data_offsets[layer_id] = {}
179+
for rank in range(self.total_tp_size):
180+
if self.is_mla:
181+
self.k_data_offsets[layer_id][0] = layer_size * layer_id
182+
break
183+
else:
184+
offset = (
185+
layer_size * layer_id
186+
+ (layer_size // self.total_tp_size) * rank
187+
)
188+
self.k_data_offsets[layer_id][rank] = offset
190189

191190
def get_tensor_and_offset_layerwise(
192191
self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
@@ -198,14 +197,17 @@ def get_tensor_and_offset_layerwise(
198197
layer_id = self._extract_layer_index(layer_name)
199198

200199
for blk_id in vllm_block_ids:
201-
k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
202200
if self.is_mla:
201+
k_data_offset = self.k_data_offsets[layer_id][0]
203202
k_tensors.append(kv_layer[blk_id])
204203
else:
204+
k_data_offset = self.k_data_offsets[layer_id][self.rank]
205205
k_tensors.append(kv_layer[0][blk_id])
206206
k_offsets.append(k_data_offset)
207207
if not self.is_mla:
208-
v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
208+
v_data_offset = (
209+
self.k_data_offsets[layer_id][self.rank] + self.min_block_size
210+
)
209211
v_tensors.append(kv_layer[1][blk_id])
210212
v_offsets.append(v_data_offset)
211213
return k_tensors + v_tensors, k_offsets + v_offsets

0 commit comments

Comments
 (0)