Commit 0919b19

Author: wuhuxiao
Commit message: blend ready
Parent: 43bdf01

File tree: 15 files changed (+1890, -56 lines)
Lines changed: 276 additions & 0 deletions

@@ -0,0 +1,276 @@
import contextlib
import csv
import json
import os
import random
import re
import time
from dataclasses import asdict

from tqdm import tqdm
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector

random.seed(0)

import sys

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import TokensPrompt

from ucm.logger import init_logger

logger = init_logger(__name__)

model = ""
data_dir = ""
path_to_dataset = ""
tokenizer = None
# 28705 is the token id for <space> char in llama model
# 151643 is the pad token id in qwen model
chunk_end_token_id = -1
chunk_pad_token_id = -1
block_size = 64
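

# Configuration comes from the MODEL_PATH, DATA_DIR and BLEND_DATASET_PATH
# environment variables; setup_environment_variables() below falls back to
# interactive prompts whenever a variable is unset or points to a missing path.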
def setup_environment_variables():
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["PYTHONHASHSEED"] = "123456"

    global model, data_dir, path_to_dataset, tokenizer, chunk_end_token_id, chunk_pad_token_id
    model = os.getenv("MODEL_PATH", "/home/models/mistralai/Mistral-7B-Instruct-v0.2")
    if not os.path.isdir(model):
        model = input(
            "Enter path to model, e.g. /home/models/mistralai/Mistral-7B-Instruct-v0.2: "
        )
        if not os.path.isdir(model):
            print("Exiting. Incorrect model path.")
            sys.exit(1)

    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
    if not os.path.isdir(data_dir):
        data_dir = input(
            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
        )
        if not os.path.isdir(data_dir):
            create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
            if create.lower() in ("", "y"):
                os.makedirs(data_dir, exist_ok=True)
            else:
                print("Exiting. Directory not created.")
                sys.exit(1)

    # currently only the 2wikimqa dataset from LongBench is supported
    path_to_dataset = os.getenv(
        "BLEND_DATASET_PATH", "/home/data/Longbench/data/2wikimqa.jsonl"
    )
    if not os.path.isfile(path_to_dataset):
        path_to_dataset = input(
            "Enter the path to the 2wikimqa dataset from LongBench, e.g. /home/data/Longbench/data/2wikimqa.jsonl: "
        )
        if not os.path.isfile(path_to_dataset):
            print("Exiting. Incorrect dataset path.")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
    # For Qwen models, pad_token_id is used for the padding block.
    # For Llama-family models, the "▁" (space) token is currently used for padding.
    chunk_pad_token_id = tokenizer.encode("▁", add_special_tokens=False)[0]
    chunk_end_token_id = chunk_pad_token_id

    if tokenizer.pad_token_id is not None:
        chunk_pad_token_id = tokenizer.pad_token_id
        chunk_end_token_id = tokenizer.pad_token_id


@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
    ktc = KVTransferConfig(
        kv_connector=name,
        kv_connector_module_path=module_path,
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_connectors": [
                {
                    "ucm_connector_name": "UcmNfsStore",
                    "ucm_connector_config": {
                        "storage_backends": data_dir,
                        "kv_block_size": 33554432,
                    },
                }
            ],
            "load_only_first_rank": False,
            "ucm_sparse_config": {
                "Blend": {
                    "chunk_end_token_id": chunk_end_token_id,
                    "compute_meta": {
                        "model.layers.1.self_attn.attn": {
                            "ratio": 0.2,
                        },
                    },
                }
            },
            "use_layerwise": True,
        },
    )

    llm_args = EngineArgs(
        model=model,
        enforce_eager=True,
        kv_transfer_config=ktc,
        max_model_len=16384 * 2,
        max_num_batched_tokens=16384 * 2,
        gpu_memory_utilization=0.8,
        block_size=block_size,
        enable_prefix_caching=False,
        distributed_executor_backend="mp",
        tensor_parallel_size=1,
        trust_remote_code=True,
    )

    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        logger.info("LLM engine is exiting.")


def get_output(
    llm: LLM,
    prompt,
    sampling_params: SamplingParams,
):
    start = time.time()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    generated_text = None
    for output in outputs:
        generated_text = output.outputs[0].text
    e2e_time = time.time() - start
    print("-" * 50)
    return e2e_time, generated_text


def pad_rag_chunks(token_ids, block_size, pad_id, end_id):
    """
    Pad token_ids with pad_id up to a multiple of block_size, ending with end_id.
    """
    # assert pad_id != end_id
    remainder = len(token_ids) % block_size

    if remainder == 0 and token_ids[-1] in [pad_id, end_id]:
        # no need to pad
        token_ids[-1] = end_id
        return token_ids

    pad_len = block_size - remainder - 1
    padded = token_ids + [pad_id] * pad_len + [end_id]
    return padded
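
# Example with hypothetical values: block_size = 4, pad_id = 0, end_id = 9 gives
# pad_rag_chunks([5, 6], 4, 0, 9) -> [5, 6, 0, 9]; the chunk is padded to a full
# block and terminated with end_id, so every RAG chunk occupies a whole number of
# KV-cache blocks and its boundary is marked by chunk_end_token_id.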


systemPrompt = """
You are a helpful assistant.
Please read the following Passages and answer the Question below.
"""


def main():
    module_path = "ucm.integration.vllm.blend_connector"
    name = "UCMBlendConnector"

    setup_environment_variables()
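
    # Experiment flow (summary of the steps below):
    #   1. warm up the padded system prompt
    #   2. baseline: generate from the original, unpadded prompt (no blend)
    #   3. prefill each padded RAG chunk separately so its KV cache is stored
    #   4. warm up the blend path with the chunks concatenated in reversed order
    #   5. cache blend: generate from the padded prompt, reusing the per-chunk caches
    #   6. re-run the original prompt (reported as "prefix cache")
    # Generated text and end-to-end time are printed for each timed run.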

    with build_llm_with_uc(module_path, name, model) as llm:
        prefill_sampling_params = SamplingParams(
            temperature=0.0, top_p=0.95, max_tokens=1
        )
        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=128)
        # choose one data row in LongBench v1 (2wikimqa)
        assert os.path.isfile(
            path_to_dataset
        ), "Incorrect dataset path. Please specify the dataset path via `export BLEND_DATASET_PATH=/path/to/longbench/2wikimqa.jsonl`"
        with open(path_to_dataset, "r") as f:
            lines = f.readlines()
            dataset_row = json.loads(lines[0])

        passages = re.findall(
            r"Passage\s+(\d+):(.*?)(?=Passage\s+\d+:|$)", dataset_row["context"], re.S
        )
        chunks = [f"Passage {i}:{passages[i][1]}" for i in range(len(passages))]
        question = "\n Question: " + dataset_row["input"] + " Answer within 5 words."
        origin_sys_prompt_ids = tokenizer.encode(systemPrompt)
        padded_sys_prompt_ids = pad_rag_chunks(
            origin_sys_prompt_ids, block_size, chunk_pad_token_id, chunk_end_token_id
        )
        # 1. sys prompt warm up
        print("--------------- 1. sys prompt: warm up ---------------")
        get_output(
            llm,
            TokensPrompt(prompt_token_ids=padded_sys_prompt_ids),
            prefill_sampling_params,
        )
        time.sleep(0.5)

        padded_contexts_ids = []
        padded_prompt_ids = padded_sys_prompt_ids
        origin_prompt_ids = origin_sys_prompt_ids
        for text_chunk in chunks:
            un_pad_ids = tokenizer.encode(text_chunk, add_special_tokens=False)
            padded_ids = pad_rag_chunks(
                un_pad_ids, block_size, chunk_pad_token_id, chunk_end_token_id
            )
            padded_prompt_ids = padded_prompt_ids + padded_ids
            origin_prompt_ids = origin_prompt_ids + un_pad_ids
            padded_contexts_ids.append(padded_ids)

        question_ids = tokenizer.encode(question, add_special_tokens=False)
        padded_prompt_ids = padded_prompt_ids + question_ids
        origin_prompt_ids = origin_prompt_ids + question_ids

        print("--------------- baseline with no cache blend ---------------")
        baseline_time, baseline_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print("--------------- cache rag chunks ---------------")
        llm.generate(
            [TokensPrompt(prompt_token_ids=ids) for ids in padded_contexts_ids],
            sampling_params,
        )
        time.sleep(0.5)

        print("--------------- warm up blend code ---------------")
        warm_up_blend_prompt_ids = padded_sys_prompt_ids
        for ids in reversed(padded_contexts_ids):
            warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + ids
        warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + question_ids
        llm.generate(
            TokensPrompt(prompt_token_ids=warm_up_blend_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print("--------------- cache blend ---------------")
        blend_time, blend_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=padded_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print("--------------- prefix cache ---------------")
        pc_time, pc_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params
        )

        print(f"Baseline generated text: {baseline_gen_text!r}")
        print(f"Baseline generation time: {baseline_time:.2f} seconds")
        print(f"Blend generated text: {blend_gen_text!r}")
        print(f"Blend generation time: {blend_time:.2f} seconds")
        print(f"Prefix Cache generated text: {pc_gen_text!r}")
        print(f"Prefix Cache generation time: {pc_time:.2f} seconds")
        print(f"Question: {dataset_row['input']}")
        print(f"Golden answer: {dataset_row['answers']}")


if __name__ == "__main__":
    main()
