diff --git a/docs/source/_static/images/blend_scheme.jpg b/docs/source/_static/images/blend_scheme.jpg
new file mode 100644
index 000000000..75638b68c
Binary files /dev/null and b/docs/source/_static/images/blend_scheme.jpg differ
diff --git a/docs/source/user-guide/sparse-attention/cacheblend.md b/docs/source/user-guide/sparse-attention/cacheblend.md
new file mode 100644
index 000000000..0f5d8e819
--- /dev/null
+++ b/docs/source/user-guide/sparse-attention/cacheblend.md
@@ -0,0 +1,109 @@
+# CacheBlend: Fast Large Language Model Serving for RAG with Cached Knowledge Fusion
+
+![blend_scheme.jpg](../../_static/images/blend_scheme.jpg)
+
+**🚀 Cached Knowledge Fusion Algorithm | 📄 EuroSys 2025 Paper**
+
+[![License](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/ModelEngine-Group/unified-cache-management/blob/main/LICENSE)
+[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)
+
+## 🌟 What is CacheBlend?
+
+**CacheBlend** is a cache-fusion system that combines multiple pre-computed KV caches when their corresponding texts
+are concatenated in the LLM input. By selectively recomputing the KV cache values of only a small fraction of tokens,
+CacheBlend reduces TTFT by 2.2 ~ 3.3× and increases throughput by 2.8 ~ 5× with a negligible quality drop.
+
+### 🎯 Key Components
+
+- **🔍 Loading Controller**: orchestrates which KV caches to load, where to load them from, and how much recomputation is needed.
+- **⚡ KV Cache Store**: manages persistent storage, lookup, and eviction of precomputed KV caches keyed by text-chunk identity.
+- **🎛️ Cache Fusor**: merges multiple chunk-level caches into a single coherent KV cache with correct cross-chunk attention, using minimal recomputation.
+
+### 🔥 Key Results
+- **2.2 ~ 3.3× speedup** in TTFT and **2.8 ~ 5× increase** in throughput for long sequences
+- **High quality preserved**: no more than a 1% ~ 3% quality drop compared to full KV recompute
+
+## 🧠 UCM Implementation
+
+### Native Block-Wise Chunk KV Cache Dump, Load, Post-Process and Recompute
+1. **🔍 Chunk Hash Encoding**: Similar to the prefix hash encoder, all blocks in each chunk are hashed starting from the same hash seed.
+2. **⚡ Combine Prefix Cache and Chunk Cache**: Since the chunk cache and the native prefix cache share the same hash space, UCM first performs a prefix cache lookup to fetch fully reusable cache, then performs a chunk cache lookup to fetch the candidate caches for blending (a block-hashing sketch is included after the citation at the end of this page).
+3. **🎯 Delta-RoPE Post-Process**: Rectifies each loaded chunk cache according to its position in the new request.
+4. **🔍 Integrate Cache Blend and First-Token Generation**: Constructs the compute mask and attention metadata from the HKVD (highest KV deviation) tokens, cache-miss tokens, and suffix tokens, then computes their KV cache in a single model forward pass (a selective-recompute sketch follows at the very end of this page).
+5. **🚀 Comprehensive Hook for the LLM Forward Pipeline**: Built on the UCM sparse module, the Blend module sparsifies the prefill tokens not only in the attention stage but also in the FFN and decoder-layer stages.
+
+## 🚀 Quick Start
+
+### Installation
+
+Blend is part of the UCM Sparse Attention module. For installation instructions, please refer to [UCM's top-level README](https://github.com/ModelEngine-Group/unified-cache-management). Once UCM is installed, Blend works out of the box; run the following example script:
+
+```bash
+export ENABLE_SPARSE=TRUE
+export DATA_DIR=/home/data/kv_cache
+export MODEL_PATH=/home/models/mistralai/Mistral-7B-Instruct-v0.2
+export BLEND_DATASET_PATH=/home/datasets/LongBench/data/2wikimqa.jsonl
+python examples/offline_inference_blend.py
+```
+
+### Basic Usage
+Similar to UCM's `offline_inference_esa.py` example, we only need to set `ucm_sparse_method` to `Blend` and specify its meta config, as shown below.
+
+```python
+...
+ktc = KVTransferConfig(
+    kv_connector=name,
+    kv_connector_module_path=module_path,
+    kv_role="kv_both",
+    kv_connector_extra_config={
+        "ucm_connectors": [
+            {
+                "ucm_connector_name": "UcmNfsStore",
+                "ucm_connector_config": {
+                    "storage_backends": data_dir,
+                    "kv_block_size": 33554432,
+                },
+            }
+        ],
+        "load_only_first_rank": False,
+        "ucm_sparse_config": {
+            "Blend": {
+                "chunk_end_token_id": chunk_end_token_id,
+                "compute_meta": {
+                    "model.layers.1.self_attn.attn": {
+                        "ratio": 0.2,
+                    },
+                },
+            }
+        },
+        "use_layerwise": True,
+    },
+)
+...
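+# How the Blend-specific fields above are used by UCMBlendConnector and the Blend sparse
+# module (the unit of kv_block_size is an assumption):
+#   - "chunk_end_token_id": sentinel token id that terminates every block-aligned RAG chunk;
+#     the connector splits the prompt into chunks wherever a KV block ends with this id.
+#   - "compute_meta": maps an attention layer name to the fraction ("ratio") of cache-hit
+#     chunk tokens recomputed as HKVD tokens, based on key deviation measured at that layer.
+#   - "kv_block_size": UcmNfsStore block size (33554432 = 32 MiB, presumably bytes).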
+```
+
+## 📊 Supported Models
+Llama-based and Qwen-based models are currently supported.
+
+## 🎓 Citation
+
+```bibtex
+@inproceedings{yao2025cacheblend,
+  title={CacheBlend: Fast large language model serving for RAG with cached knowledge fusion},
+  author={Yao, Jiayi and Li, Hanchen and Liu, Yuhan and Ray, Siddhant and Cheng, Yihua and Zhang, Qizheng and Du, Kuntai and Lu, Shan and Jiang, Junchen},
+  booktitle={Proceedings of the Twentieth European Conference on Computer Systems},
+  pages={94--109},
+  year={2025}
+}
+```
+
+---
+
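+For a concrete picture of steps 1 and 2 of the implementation notes, the sketch below shows block-aligned chunk padding
+and per-block chain hashing under simplified assumptions: a plain SHA-256 chain stands in for the connector's request
+hasher, and `pad_chunk` / `block_hashes` are illustrative names. The real implementations are `pad_rag_chunks` in
+`examples/offline_inference_blend.py` and `_generate_hash` in `ucm/integration/vllm/blend_connector.py`.
+
+```python
+import hashlib
+
+BLOCK_SIZE = 64  # must match the engine's KV block size
+
+
+def pad_chunk(token_ids: list[int], pad_id: int, end_id: int) -> list[int]:
+    """Pad a chunk up to a block boundary and make its last token end_id,
+    so chunk boundaries can be detected from the final token of a block."""
+    pad_len = (-len(token_ids) - 1) % BLOCK_SIZE
+    return token_ids + [pad_id] * pad_len + [end_id]
+
+
+def block_hashes(token_ids: list[int], seed: str = "UCM_SEED") -> list[str]:
+    """Chain-hash full blocks: each block hash depends on its parent hash, so chunk
+    cache hashes and prefix cache hashes live in the same hash space."""
+    parent, hashes = seed, []
+    for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
+        block = token_ids[start : start + BLOCK_SIZE]
+        parent = hashlib.sha256(f"{parent}:{block}".encode()).hexdigest()
+        hashes.append(parent)
+    return hashes
+
+
+if __name__ == "__main__":
+    chunk = list(range(150))                     # a 150-token passage
+    padded = pad_chunk(chunk, pad_id=0, end_id=2)
+    print(len(padded) % BLOCK_SIZE, padded[-1])  # -> 0 2 (block-aligned, ends with end_id)
+    print(len(block_hashes(padded)))             # -> 3 block hashes for 192 tokens
+```
+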
+
+**🌟 Star the [UCM](https://github.com/ModelEngine-Group/unified-cache-management) repository if you find CacheBlend useful!**
+
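+The HKVD recompute in step 4 of the implementation notes boils down to ranking the reused chunk tokens by how far their
+loaded keys deviate from the freshly computed keys and recomputing only the top fraction given by `ratio`. The snippet
+below is a minimal sketch of that selection; `select_hkvd_tokens` is an illustrative name, and the production logic
+(block masks, prefix hits, attention-metadata updates) lives in `Blend.attention_begin` in `ucm/sparse/blend/blend.py`.
+
+```python
+import torch
+
+
+def select_hkvd_tokens(loaded_k: torch.Tensor, fresh_k: torch.Tensor, ratio: float) -> torch.Tensor:
+    """Return indices of the chunk tokens whose cached keys deviate most from the
+    freshly computed keys; only these HKVD tokens are recomputed during blending.
+
+    loaded_k, fresh_k: [num_chunk_tokens, num_kv_heads * head_dim]
+    """
+    # Per-token L1 deviation between the reused (stale) keys and the exact keys.
+    deviation = (loaded_k.float() - fresh_k.float()).abs().sum(dim=-1)
+    top_k = max(1, int(deviation.numel() * ratio))
+    return torch.topk(deviation, k=top_k).indices
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    num_tokens, dim = 256, 128                     # e.g. 4 KV blocks of 64 tokens each
+    fresh_k = torch.randn(num_tokens, dim)
+    loaded_k = fresh_k + 0.01 * torch.randn(num_tokens, dim)
+    loaded_k[17] += 1.0                            # one token whose cached key drifted badly
+    hkvd = select_hkvd_tokens(loaded_k, fresh_k, ratio=0.2)
+    print(hkvd.numel(), bool((hkvd == 17).any()))  # -> 51 True
+```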
diff --git a/examples/offline_inference_blend.py b/examples/offline_inference_blend.py new file mode 100644 index 000000000..0de105f55 --- /dev/null +++ b/examples/offline_inference_blend.py @@ -0,0 +1,273 @@ +import contextlib +import csv +import json +import os +import random +import re +import time +from dataclasses import asdict + +from tqdm import tqdm +from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector + +random.seed(0) + +import sys + +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import TokensPrompt + +from ucm.logger import init_logger + +logger = init_logger(__name__) + +model = "" +data_dir = "" +path_to_dataset = "" +tokenizer = None +# 28705 is the token id for char in llama model +# 151643 is the pad token id in qwen model +chunk_end_token_id = -1 +chunk_pad_token_id = -1 +block_size = 64 + + +def setup_environment_variables(): + os.environ["VLLM_USE_V1"] = "1" + os.environ["PYTHONHASHSEED"] = "123456" + + global model, data_dir, path_to_dataset, tokenizer, chunk_end_token_id, chunk_pad_token_id + model = os.getenv("MODEL_PATH", "/home/models/mistralai/Mistral-7B-Instruct-v0.2") + if not os.path.isdir(model): + model = input( + "Enter path to model, e.g./home/models/mistralai/Mistral-7B-Instruct-v0.2: " + ) + if not os.path.isdir(model): + print("Exiting. Incorrect model_path") + sys.exit(1) + + data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache") + if not os.path.isdir(data_dir): + data_dir = input( + "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: " + ) + create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ") + if create.lower() == "y": + os.makedirs(data_dir, exist_ok=True) + else: + print("Exiting. Directory not created.") + sys.exit(1) + + # now support wikimqa + path_to_dataset = os.getenv( + "BLEND_DATASET_PATH", "/home/data/Longbench/data/2wikimqa.jsonl" + ) + if not os.path.isfile(path_to_dataset): + path_to_dataset = input( + "Enter path of one of 2wikimqa dataset in longbench, e.g. /home/data/Longbench/data/2wikimqa.jsonl: " + ) + if not os.path.isfile(path_to_dataset): + print("Exiting. 
Incorrect dataset path") + sys.exit(1) + + tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) + # as for Qwen model, use pad_token_id for padding block + # as for Llama model, current use unk_token for padding block + chunk_pad_token_id = tokenizer.encode("▁", add_special_tokens=False)[0] + chunk_end_token_id = chunk_pad_token_id + + if tokenizer.pad_token_id is not None: + chunk_pad_token_id = tokenizer.pad_token_id + chunk_end_token_id = tokenizer.pad_token_id + + +@contextlib.contextmanager +def build_llm_with_uc(module_path: str, name: str, model: str): + ktc = KVTransferConfig( + kv_connector=name, + kv_connector_module_path=module_path, + kv_role="kv_both", + kv_connector_extra_config={ + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "kv_block_size": 33554432, + }, + } + ], + "load_only_first_rank": False, + "ucm_sparse_config": { + "Blend": { + "chunk_end_token_id": chunk_end_token_id, + "compute_meta": { + "model.layers.1.self_attn.attn": { + "ratio": 0.2, + }, + }, + } + }, + "use_layerwise": True, + }, + ) + + llm_args = EngineArgs( + model=model, + enforce_eager=True, + kv_transfer_config=ktc, + max_model_len=16384 * 2, + max_num_batched_tokens=16384 * 2, + gpu_memory_utilization=0.8, + block_size=block_size, + enable_prefix_caching=False, + distributed_executor_backend="mp", + tensor_parallel_size=1, + trust_remote_code=True, + ) + + llm = LLM(**asdict(llm_args)) + try: + yield llm + finally: + logger.info("LLM engine is exiting.") + + +def get_output( + llm: LLM, + prompt, + sampling_params: SamplingParams, +): + start = time.time() + outputs = llm.generate(prompt, sampling_params) + print("-" * 50) + generated_text = None + for output in outputs: + generated_text = output.outputs[0].text + e2e_time = time.time() - start + print("-" * 50) + return e2e_time, generated_text + + +def pad_rag_chunks(token_ids, block_size, pad_id, end_id): + """ + pad token_ids with pad_id and end up with end_id + """ + # assert pad_id != end_id + remainder = len(token_ids) % block_size + + if remainder == 0 and token_ids[-1] in [pad_id, end_id]: + # no need to pad + token_ids[-1] = end_id + return token_ids + + pad_len = block_size - remainder - 1 + padded = token_ids + [pad_id] * pad_len + [end_id] + return padded + + +systemPrompt = "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n" + + +def main(): + module_path = "ucm.integration.vllm.blend_connector" + name = "UCMBlendConnector" + + setup_environment_variables() + + with build_llm_with_uc(module_path, name, model) as llm: + prefill_sampling_params = SamplingParams( + temperature=0.0, top_p=0.95, max_tokens=1 + ) + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=128) + # choose one data row in LongBenchV1 (wikimqa) + assert os.path.isfile( + path_to_dataset + ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`" + with open(path_to_dataset, "r") as f: + lines = f.readlines() + dataset_row = json.loads(lines[0]) + + passages = re.findall( + r"Passage\s+(\d+):(.*?)(?=Passage\s+\d+:|$)", dataset_row["context"], re.S + ) + chunks = [f"Passage {i}:{passages[i][1]}" for i in range(len(passages))] + question = f"\n\nAnswer the question based on the given passages. Answer the question within 5 words. Do NOT repeat the question or output any other words. 
Question: {dataset_row["input"]}\nAnswer:" + origin_sys_prompt_ids = tokenizer.encode(systemPrompt) + padded_sys_prompt_ids = pad_rag_chunks( + origin_sys_prompt_ids, block_size, chunk_pad_token_id, chunk_end_token_id + ) + # 1. sys prompt warm up + print(f"---------------1. sys prompt: warm up---------------") + get_output( + llm, + TokensPrompt(prompt_token_ids=padded_sys_prompt_ids), + prefill_sampling_params, + ) + time.sleep(0.5) + + padded_contexts_ids = [] + padded_prompt_ids = padded_sys_prompt_ids + origin_prompt_ids = origin_sys_prompt_ids + for text_chunk in chunks: + un_pad_ids = tokenizer.encode(text_chunk, add_special_tokens=False) + padded_ids = pad_rag_chunks( + un_pad_ids, block_size, chunk_pad_token_id, chunk_end_token_id + ) + padded_prompt_ids = padded_prompt_ids + padded_ids + origin_prompt_ids = origin_prompt_ids + un_pad_ids + padded_contexts_ids.append(padded_ids) + + question_ids = tokenizer.encode(question, add_special_tokens=False) + padded_prompt_ids = padded_prompt_ids + question_ids + origin_prompt_ids = origin_prompt_ids + question_ids + + print(f"--------------- baseline with no cache blend ---------------") + baseline_time, baseline_gen_text = get_output( + llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params + ) + time.sleep(0.5) + + print(f"--------------- cache rag chunks ---------------") + llm.generate( + [TokensPrompt(prompt_token_ids=ids) for ids in padded_contexts_ids], + sampling_params, + ) + time.sleep(0.5) + + print(f"--------------- warm up blend code ---------------") + warm_up_blend_prompt_ids = padded_sys_prompt_ids + for ids in reversed(padded_contexts_ids): + warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + ids + warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + question_ids + llm.generate( + TokensPrompt(prompt_token_ids=warm_up_blend_prompt_ids), sampling_params + ) + time.sleep(0.5) + + print(f"--------------- cache blend ---------------") + blend_time, blend_gen_text = get_output( + llm, TokensPrompt(prompt_token_ids=padded_prompt_ids), sampling_params + ) + time.sleep(0.5) + + print(f"--------------- prefix cache ---------------") + pc_time, pc_gen_text = get_output( + llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params + ) + + print(f"Baseline generated text: {baseline_gen_text!r}") + print(f"Baseline generated cost time: {baseline_time:.2f} seconds") + print(f"Blend generated text: {blend_gen_text!r}") + print(f"Blend generated cost time: {blend_time:.2f} seconds") + print(f"Prefix Cache generated text: {pc_gen_text!r}") + print(f"Prefix Cache generated cost time: {pc_time:.2f} seconds") + print(f"Question:{dataset_row['input']}") + print(f"Golden answer:{dataset_row["answers"]}") + + +if __name__ == "__main__": + main() diff --git a/ucm/integration/vllm/blend_connector.py b/ucm/integration/vllm/blend_connector.py new file mode 100644 index 000000000..eaba1b381 --- /dev/null +++ b/ucm/integration/vllm/blend_connector.py @@ -0,0 +1,508 @@ +import itertools +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, List, Self, Tuple + +import torch +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request + +from ucm.integration.vllm.ucm_connector import ( + RequestDispatchMeta, + RequestHasher, + RequestMeta, + UCMConnectorMetadata, + UCMDirectConnector, +) +from 
ucm.logger import init_logger +from ucm.shared.metrics import ucmmonitor +from ucm.sparse.blend.blockwise_rope import block_wise_rope_forward + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + +logger = init_logger(__name__) + + +@dataclass +class ChunkMetaData: + # [start, start + len) + start_token_dix: int + chunk_tokens_len: int + + start_blk_idx: int + chunk_blks_len: int + + cached_start_position: int + + vllm_blk_ids: List[int] = field(default_factory=list) + chunk_blks_hash: List[str] = field(default_factory=list) + store_hits: List[bool] = field(default_factory=list) + + @property + def end_token_dix(self) -> int: + return self.start_token_dix + self.chunk_tokens_len + + @property + def end_blk_idx(self) -> int: + return self.start_blk_idx + self.chunk_blks_len + + @property + def cached_end_position(self) -> int: + return self.cached_start_position + self.chunk_tokens_len + + @property + def position_offset(self) -> int: + return self.start_token_dix - self.cached_start_position + + @property + def hits_vllm_blk_ids(self) -> List[int]: + return list(itertools.compress(self.vllm_blk_ids, self.store_hits)) + + @property + def hits_chunk_blks_hash(self) -> List[str]: + return list(itertools.compress(self.chunk_blks_hash, self.store_hits)) + + def merge_chunk(self, temp_chunk_meta: Self) -> None: + # current we use a fix pattern(end with a fix token id) to recognize the text token chunk + # in some special situation, one text chunk maybe split as multi text chunk, so we should merge them into one + self.chunk_tokens_len += temp_chunk_meta.chunk_tokens_len + self.chunk_blks_len += temp_chunk_meta.chunk_blks_len + self.chunk_blks_hash += temp_chunk_meta.chunk_blks_hash + + def update_meta_partial_pc(self, num_pc_part_blks: int, block_size: int) -> None: + if num_pc_part_blks > 0: + self.start_token_dix += num_pc_part_blks * block_size + self.chunk_tokens_len -= num_pc_part_blks * block_size + + self.start_blk_idx += num_pc_part_blks + self.chunk_blks_len -= num_pc_part_blks + + self.chunk_blks_hash = self.chunk_blks_hash[num_pc_part_blks:] + self.store_hits = self.store_hits[num_pc_part_blks:] + self.cached_start_position += num_pc_part_blks * block_size + + +class BlendStage(Enum): + BUILD_CHUNK_CACHE = auto() + BUILD_PREFIX_CACHE = auto() + CACHE_BLEND = auto() + + def is_blend_cache(self) -> bool: + return self == BlendStage.CACHE_BLEND + + def is_prefix_cache(self) -> bool: + return self == BlendStage.BUILD_PREFIX_CACHE + + +@dataclass +class BlendRequestMeta: + ucm_block_hashs: list[str] = field(default_factory=list) + # hbm pc is not supported + hbm_hit_block_num: int = 0 + # ucm pc is supported + pc_hit_block_num: int = 0 + chunks_meta: List[ChunkMetaData] = field(default_factory=list) + blend_stage: BlendStage = BlendStage.BUILD_PREFIX_CACHE + + +@dataclass +class BlendRequestDispatchMeta(RequestDispatchMeta): + chunks_meta: List[ChunkMetaData] + + +@dataclass +class UCMBlendConnectorMetadata(UCMConnectorMetadata): + request_meta: dict[str, BlendRequestDispatchMeta] = field(default_factory=dict) + + +class UCMBlendConnector(UCMDirectConnector): + """ + This Connector process chunk hash and prefix cache + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + ucm_sparse_config = self.launch_config.get("ucm_sparse_config", []) + self.blend_stage = BlendStage.BUILD_PREFIX_CACHE + self.req2rag_load_chunks: dict[str, list[ChunkMetaData]] = {} + if "Blend" in ucm_sparse_config: + blend_config 
= ucm_sparse_config["Blend"] + self.enable_blend = True + self.chunk_end_token_id = blend_config["chunk_end_token_id"] + else: + raise "UCMBlendConnector init failed, please check your config" + + self.ucm_chunk_end_hash: int = self.request_hasher("UCM_CHUNK_END_HASH") + self.ucm_chunk_continue_hash: int = self.request_hasher( + "UCM_CHUNK_CONTINUE_HASH" + ) + self.requests_blend_meta: dict[str, BlendRequestMeta] = {} + self.cos_sin_cache: torch.Tensor = None + + # if chunk cache hits less than min_blend_threshold, no need to cache blend + self.min_blend_threshold = 16 + + def _generate_hash( + self, block_size: int, token_ids: list[int], parent_block_hash_value: int + ) -> list[str]: + ret = [] + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, block_token_ids_tuple) + ) + parent_block_hash_value = hash_value + ret.append(str(hash_value)) + + return ret + + def _process_req(self, all_token_ids: List[int]): + """ + pre-assumption, we explicitly construct block-padded chunk req to make it cached all tokens + beside chunk-build req, we try to split chunk from req, if no chunk exist, it just builds naive prefix cache + if chunk found, first we should match the prefix cache as much as possible, cause, they can be fully reused + then for other chunk blocks, if store hit num of block hash is less than threshold, we do not conduct cache blend + finally, if there are quite many chunk block-hits, we do cache blend to get TTFT-promot + """ + chunks_meta = [] + prefix_block_hashes = self._generate_hash( + self.block_size, all_token_ids, RequestHasher._SEED_HASH + ) + if ( + all_token_ids[-1] == self.chunk_end_token_id + and len(all_token_ids) % self.block_size == 0 + ): + return ( + BlendStage.BUILD_CHUNK_CACHE, + prefix_block_hashes, + chunks_meta, + [], + ) + + start_blk_idx = 0 + start_token_dix = 0 + req_chunks_hashes = [] + + for end_blk_idx, end_token_idx in enumerate( + range(self.block_size - 1, len(all_token_ids), self.block_size) + ): + # only compare the last token id in each blk to split chunk + # in future we should add chunk info as llm engine input,then pass them to schedule out + # but this will bring lots of modification to engine. 
+ if all_token_ids[end_token_idx] == self.chunk_end_token_id: + chunk_token_ids = all_token_ids[start_token_dix : end_token_idx + 1] + chunk_blks_hash = self._generate_hash( + self.block_size, chunk_token_ids, RequestHasher._SEED_HASH + ) + + chunk_blks_len = end_blk_idx - start_blk_idx + 1 + chunk_tokens_len = chunk_blks_len * self.block_size + + rag_chunk_meta = ChunkMetaData( + start_token_dix=start_token_dix, + chunk_tokens_len=chunk_tokens_len, + start_blk_idx=start_blk_idx, + chunk_blks_len=chunk_blks_len, + chunk_blks_hash=chunk_blks_hash, + cached_start_position=0, + ) + + # update for next rag chunk + start_blk_idx = end_blk_idx + 1 + start_token_dix = end_token_idx + 1 + + chunks_meta.append(rag_chunk_meta) + req_chunks_hashes.extend(chunk_blks_hash) + + if chunks_meta: + # found chunk, as for suffix part(such as user question about chunk), current no need to cache hit and dump + return ( + BlendStage.CACHE_BLEND, + prefix_block_hashes, + chunks_meta, + req_chunks_hashes, + ) + else: + return ( + BlendStage.BUILD_PREFIX_CACHE, + prefix_block_hashes, + chunks_meta, + req_chunks_hashes, + ) + + def _get_req_chunk_hit( + self, + req_stage: BlendStage, + prefix_block_hashes: List[str], + req_chunks_meta: List[ChunkMetaData], + req_chunks_hashes: List[str], + ) -> Tuple[int, int]: + + # first perform prefix cache lookup + pc_lookup_results = self.store.lookup(prefix_block_hashes) + pc_hit_blocks = 0 + chunk_hit_blocks = 0 + + for i, hit in enumerate(pc_lookup_results): + if not hit: + break + pc_hit_blocks += 1 + + if not req_stage.is_blend_cache(): + return pc_hit_blocks, chunk_hit_blocks + + # then perform chunk cache lookup + chunk_lookup_results = self.store.lookup(req_chunks_hashes[pc_hit_blocks:]) + chunk_hit_blocks = sum(chunk_lookup_results) + + chunk_lookup_results = pc_lookup_results[:pc_hit_blocks] + chunk_lookup_results + # for cache blend + for i, chunk_meta in enumerate(req_chunks_meta): + chunk_meta.store_hits = chunk_lookup_results[ + chunk_meta.start_blk_idx : chunk_meta.end_blk_idx + ] + first_chunk_meta = req_chunks_meta[0] + first_chunk_meta.update_meta_partial_pc(pc_hit_blocks, self.block_size) + # remove total pc hit chunk + if first_chunk_meta.chunk_tokens_len == 0: + req_chunks_meta.pop(0) + + return pc_hit_blocks, chunk_hit_blocks + + def _generate_blend_dispatch_meta( + self, + req_meta: BlendRequestMeta, + new_tokens: int, + vllm_block_ids: list[int], + ) -> BlendRequestDispatchMeta: + """ + Request Blocks layout: + Stage: Build Prefix Cache or Build Chunk Cache (max one chunk per req) + ---------------------------------------------------------------------------------------------------------- + | prefix cache (at first chunk) | other chunk cache | + ---------------------------------------------------------------------------------------------------------- + | LOAD | DUMP | + ---------------------------------------------------------------------------------------------------------- + | REUSE | RECOMPUTE | + ---------------------------------------------------------------------------------------------------------- + + + Stage: Cache Blend + ---------------------------------------------------------------------------------------------------------- + | prefix cache at first chunk | other chunk cache hit | other chunk cache miss | suffix part(question) | + ---------------------------------------------------------------------------------------------------------- + | LOAD | LOAD | NO NEED TO DUMP | NO NEED TO DUMP | + 
---------------------------------------------------------------------------------------------------------- + | REUSE | REUSE & RECOMPUTE | RECOMPUTE | RECOMPUTE | + ---------------------------------------------------------------------------------------------------------- + + """ + + # current not support chunk prefill, cause the topK high deviation KV should come from the all tokens + pc_hit_block_num = req_meta.pc_hit_block_num + ucm_block_hashs = req_meta.ucm_block_hashs + # load prefix part + load_ucm_block_ids, load_vllm_block_ids = ( + ucm_block_hashs[:pc_hit_block_num], + vllm_block_ids[:pc_hit_block_num], + ) + dump_ucm_block_ids, dump_vllm_block_ids = [], [] + + if req_meta.blend_stage.is_blend_cache(): + # just need to load, in future we may create a multi-chunk hash to dump and reuse the blended cache + for chunk_meta in req_meta.chunks_meta: + chunk_meta.vllm_blk_ids = vllm_block_ids[ + chunk_meta.start_blk_idx : chunk_meta.end_blk_idx + ] + load_ucm_block_ids.extend(chunk_meta.hits_chunk_blks_hash) + load_vllm_block_ids.extend(chunk_meta.hits_vllm_blk_ids) + return BlendRequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + req_meta.chunks_meta, + ) + + # build cache stage + dump_ucm_block_ids, dump_vllm_block_ids = ( + ucm_block_hashs[pc_hit_block_num:], + vllm_block_ids[pc_hit_block_num : len(ucm_block_hashs)], + ) + return BlendRequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + req_meta.chunks_meta, + ) + + def _post_process_chunk_cache(self, k_cache, vllm_ids, positions) -> None: + """ + post process loaded chunk kcache + """ + if self.cos_sin_cache is None: + raise "Please call setup model first." + # triton kernl for block-wise delta rope + block_wise_rope_forward(k_cache, vllm_ids, positions, self.cos_sin_cache) + + def _register_cos_sin_cache(self, model: "Model") -> None: + try: + rotary_emb = model.model.layers[0].self_attn.rotary_emb + self.cos_sin_cache = rotary_emb.cos_sin_cache + except Exception: + raise "get cos_sin_cache from model failed! 
current not implemented for this model" + + def setup_model(self, model: "Model") -> None: + self._register_cos_sin_cache(model) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + + # current not support HBM prefix cache, cause the blended cached have a ground view of all chunks + # so they can not apply to other req + assert num_computed_tokens == 0 + all_token_ids = request.all_token_ids + + max_blk_num = len(all_token_ids) // self.block_size + + if max_blk_num == 0: + return 0, False + + req_stage, prefix_block_hashes, req_chunks_meta, req_chunks_hashes = ( + self._process_req(all_token_ids) + ) + + pc_hit_blocks, chunk_hit_blocks = self._get_req_chunk_hit( + req_stage, prefix_block_hashes, req_chunks_meta, req_chunks_hashes + ) + + if chunk_hit_blocks < self.min_blend_threshold: + req_stage = BlendStage.BUILD_PREFIX_CACHE + req_chunks_meta = [] + + req_block_hashes = prefix_block_hashes + if req_stage.is_blend_cache(): + req_block_hashes = req_chunks_hashes + + logger.info( + f"request_id: {request.request_id}, " + f"total_blocks_num: {max_blk_num}, " + f"req_stage: {req_stage}, " + f"first chunk prefix hit: {pc_hit_blocks}, " + f"chunks cache total hit: {chunk_hit_blocks}, " + ) + if self.metrics_config: + self.monitor.update_stats( + "ConnStats", + {"interval_lookup_hit_rates": chunk_hit_blocks / max_blk_num}, + ) + + pc_hit_tokens = pc_hit_blocks * self.block_size + + # When all the tokens are cached in ssd or hbm, + # we need to recompute the last token. This if condition will be removed + # once vLLM scheduler provides a better solution in the future. + if pc_hit_tokens == request.num_tokens: + pc_hit_tokens -= 1 + + self.requests_blend_meta[request.request_id] = BlendRequestMeta( + ucm_block_hashs=req_block_hashes, + pc_hit_block_num=pc_hit_blocks, + chunks_meta=req_chunks_meta, + blend_stage=req_stage, + ) + + return pc_hit_tokens, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + pass + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + requests_dispatch_meta = {} + # for new request, we need to load and dump + for request in scheduler_output.scheduled_new_reqs: + request_id, vllm_block_ids = request.req_id, request.block_ids[0] + req_meta = self.requests_blend_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_blend_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + vllm_block_ids, + ) + + # for cached request, there are 3 situation: + # 1. chunked prefill: we should make sure this will not happen + # 2. resumed: we need to handle like new request + # 3. 
TODO decode stage: nothing happened + scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): + # >= 0.9.2 + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_blend_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = ( + self._generate_blend_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + scheduled_cached_reqs.new_block_ids[i][0], + ) + ) + else: + for request in scheduled_cached_reqs: + request_id = request.request_id + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_blend_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = ( + self._generate_blend_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids[0], + ) + ) + + # clear finished request + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + + return UCMBlendConnectorMetadata(requests_dispatch_meta) + + def wait_for_layer_load(self, layer_name: str) -> None: + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMBlendConnectorMetadata) + + all_hits_vllm_ids = [] + positions = [] + k_cache = self.kv_caches[layer_name][0] + for request_id, request in metadata.request_meta.items(): + for chunk_meta in request.chunks_meta: + all_hits_vllm_ids.extend(chunk_meta.hits_vllm_blk_ids) + positions.extend( + [chunk_meta.position_offset] * len(chunk_meta.hits_vllm_blk_ids) + ) + if all_hits_vllm_ids: + vllm_ids = torch.tensor(all_hits_vllm_ids, device=k_cache.device) + positions = torch.tensor(positions, device=k_cache.device) + self._post_process_chunk_cache(k_cache, vllm_ids, positions) + pass diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py index f3927ece9..8d4a09af0 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py @@ -58,7 +58,7 @@ def _apply_ascend_patch() -> None: def _patch_attention_v1() -> None: """Patch attention_v1.py for vLLM-Ascend.""" try: - from typing import List + from typing import List, Optional import torch from vllm.forward_context import ForwardContext, get_forward_context @@ -72,15 +72,19 @@ def maybe_execute_sparse_attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, + phase: Optional[str] = None, ): if not has_ucm_sparse(): - return + return query, key, value, output ucm_sparse = get_ucm_sparse() attn_metadata = forward_context.attn_metadata if attn_metadata is None: - return - ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) + return query, key, value, output + return ucm_sparse.attention_begin( + query, key, value, layer_name, forward_context, output, phase + ) attention_v1.maybe_execute_sparse_attention_begin = ( maybe_execute_sparse_attention_begin @@ -139,7 +143,7 @@ def unified_ascend_attention_with_output_impl( self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] if not self.use_mla: - maybe_execute_sparse_attention_begin( + query, key, value, _ = maybe_execute_sparse_attention_begin( query, key, value, layer_name, 
forward_context ) self.impl.forward( @@ -386,13 +390,15 @@ def forward( # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy # TODO: use an elegant way to overlap - maybe_execute_sparse_attention_begin( - prefill_q, - prefill_k_c_normed, - prefill_k_pe, - layer.layer_name, - forward_context, - "prefill", + prefill_q, prefill_k_c_normed, prefill_k_pe, _ = ( + maybe_execute_sparse_attention_begin( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + layer.layer_name, + forward_context, + phase="prefill", + ) ) output_prefill = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata @@ -414,13 +420,15 @@ def forward( "prefill", ) if has_decode: - maybe_execute_sparse_attention_begin( - torch.cat([decode_ql_nope, decode_q_pe], dim=-1), - decode_ql_nope, - decode_q_pe, - layer.layer_name, - forward_context, - "decode", + _, decode_ql_nope, decode_q_pe, _ = ( + maybe_execute_sparse_attention_begin( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + layer.layer_name, + forward_context, + phase="decode", + ) ) if self.running_in_graph: return self._forward_decode( diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py index 2a697efb0..15bc8037b 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py @@ -49,6 +49,8 @@ def _apply_sparse_adapt() -> None: _patch_gpu_worker() _patch_scheduler_output() _patch_scheduler() + _patch_llama_model() + _patch_qwen_model() logger.info("UCM sparse adapt patches applied successfully") except Exception as e: logger.error(f"Could not apply sparse adapt patches: {e}") @@ -147,25 +149,103 @@ def _patch_attention_layer() -> None: from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + def attn_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + if self.use_output: + output_shape = output_shape if output_shape is not None else query.shape + output = torch.zeros( + output_shape, dtype=query.dtype, device=query.device + ) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output, + ) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name + ) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata + ) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name + ) + def maybe_execute_sparse_attention_begin( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, phase: Optional[str] = None, ): if not has_ucm_sparse(): - return + return query, key, value, output ucm_sparse = get_ucm_sparse() attn_metadata = forward_context.attn_metadata if attn_metadata is None: - return + return query, key, value, output - ucm_sparse.attention_begin( - query, key, value, layer_name, forward_context, phase + return ucm_sparse.attention_begin( + query, key, value, layer_name, forward_context, output, phase ) def maybe_execute_sparse_attention_finished( @@ -221,7 +301,7 @@ def unified_attention_impl( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - maybe_execute_sparse_attention_begin( + query, key, value, _ = maybe_execute_sparse_attention_begin( query, key, value, layer_name, forward_context ) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) @@ -247,8 +327,8 @@ def unified_attention_with_output_impl( self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] if not self.use_mla: - maybe_execute_sparse_attention_begin( - query, key, value, layer_name, forward_context + query, key, value, output = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context, output ) self.impl.forward( self, @@ -281,6 +361,7 @@ def unified_attention_with_output_impl( layer.maybe_execute_sparse_attention_finished = ( maybe_execute_sparse_attention_finished ) + layer.Attention.forward = attn_forward layer.unified_attention = unified_attention_impl layer.unified_attention_with_output = unified_attention_with_output_impl @@ -412,13 +493,15 @@ def forward( ) if has_prefill: - maybe_execute_sparse_attention_begin( - prefill_q, - prefill_k_c_normed, - prefill_k_pe, - layer.layer_name, - forward_context, - "prefill", + prefill_q, prefill_k_c_normed, prefill_k_pe, _ = ( + maybe_execute_sparse_attention_begin( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + layer.layer_name, + forward_context, + phase="prefill", + ) ) output[num_decode_tokens:] 
= self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata @@ -443,13 +526,15 @@ def forward( decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) - maybe_execute_sparse_attention_begin( - torch.cat([decode_ql_nope, decode_q_pe], dim=-1), - decode_ql_nope, - decode_q_pe, - layer.layer_name, - forward_context, - "decode", + _, decode_ql_nope, decode_q_pe, _ = ( + maybe_execute_sparse_attention_begin( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + layer.layer_name, + forward_context, + phase="decode", + ) ) output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata @@ -1155,17 +1240,24 @@ def maybe_execute_ucm_sparse_begin( ): if not has_ucm_sparse(): return + + if has_kv_transfer_group(): + uc_connector = get_kv_transfer_group() + uc_setup_model = getattr(uc_connector, "setup_model", None) + if callable(uc_setup_model): + uc_setup_model(self.model) + ucm_sparse = get_ucm_sparse() ucm_sparse.build_sparse_meta( scheduler_output, self.requests, self.input_batch, attn_metadata ) ucm_sparse.execute_begin(scheduler_output) - def maybe_execute_ucm_sparse_finished(self): + def maybe_execute_ucm_sparse_finished(self, logits_indices): if not has_ucm_sparse(): - return + return logits_indices ucm_sparse = get_ucm_sparse() - ucm_sparse.execute_finished() + return ucm_sparse.execute_finished(logits_indices) def ucm_sparse_request_finished_in_worker(self, request_id: str | int): if not has_ucm_sparse(): @@ -1749,7 +1841,7 @@ def execute_model( ) self.maybe_wait_for_kv_save() - self.maybe_execute_ucm_sparse_finished() + logits_indices = self.maybe_execute_ucm_sparse_finished(logits_indices) finished_sending, finished_recving = self.get_finished_kv_transfers( scheduler_output @@ -1990,3 +2082,235 @@ def patched_init_worker_distributed_environment( ) except ImportError: logger.warning("Could not patch gpu worker - module not found") + + +# ==================== vllm/model_executor/models/llama.py ==================== +def _patch_llama_model() -> None: + """Patch gpu worker to add UCM sparse support.""" + try: + from typing import Optional, Union + + import torch + from vllm.config import VllmConfig + from vllm.distributed import get_pp_group + from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaModel + from vllm.sequence import IntermediateTensors + + from ucm.sparse.state import ( + get_ucm_sparse, + has_ucm_sparse, + maybe_execute_sparse_ffn_begin, + maybe_execute_sparse_ffn_finished, + maybe_execute_sparse_layer_begin, + maybe_execute_sparse_layer_finished, + ) + + def llamaDecoderLayer_forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states + ) + ###################### + ### UCM PATCH START### + hidden_states, residual = maybe_execute_sparse_ffn_begin( + hidden_states, residual + ) + ### UCM PATCH END ### + ###################### + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + ###################### + 
### UCM PATCH START### + hidden_states, residual = maybe_execute_sparse_ffn_finished( + hidden_states, residual + ) + ### UCM PATCH END ### + ###################### + return hidden_states, residual + + LlamaDecoderLayer.forward = llamaDecoderLayer_forward + + def llamaModel_forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + ###################### + ### UCM PATCH START### + positions, hidden_states, residual = maybe_execute_sparse_layer_begin( + positions, hidden_states, residual + ) + ### UCM PATCH END ### + ###################### + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + ###################### + ### UCM PATCH START### + positions, hidden_states, residual = ( + maybe_execute_sparse_layer_finished( + positions, hidden_states, residual + ) + ) + ### UCM PATCH END ### + ###################### + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + LlamaModel.forward = llamaModel_forward + + except ImportError: + logger.warning("Could not patch llama modelr - module not found") + + +# ==================== vllm/model_executor/models/qwen2.py ==================== +def _patch_qwen_model() -> None: + """Patch gpu worker to add UCM sparse support.""" + try: + from typing import Optional, Union + + import torch + from vllm.config import VllmConfig + from vllm.distributed import get_pp_group + from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model + from vllm.sequence import IntermediateTensors + + from ucm.sparse.state import ( + get_ucm_sparse, + has_ucm_sparse, + maybe_execute_sparse_ffn_begin, + maybe_execute_sparse_ffn_finished, + maybe_execute_sparse_layer_begin, + maybe_execute_sparse_layer_finished, + ) + + def qwen2DecoderLayer_forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + ###################### + ### UCM PATCH START### + residual, hidden_states = maybe_execute_sparse_ffn_begin( + residual, hidden_states + ) + ### UCM PATCH END ### + ###################### + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + ###################### + ### UCM PATCH START### + 
residual, hidden_states = maybe_execute_sparse_ffn_finished( + residual, hidden_states + ) + ### UCM PATCH END ### + ###################### + return hidden_states, residual + + Qwen2DecoderLayer.forward = qwen2DecoderLayer_forward + + def qwen2Model_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer : self.end_layer]: + ###################### + ### UCM PATCH START### + positions, hidden_states, residual = maybe_execute_sparse_layer_begin( + positions, + hidden_states, + residual, + ) + ### UCM PATCH END ### + ###################### + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + ###################### + ### UCM PATCH START### + positions, hidden_states, residual = ( + maybe_execute_sparse_layer_finished( + positions, hidden_states, residual + ) + ) + ### UCM PATCH END ### + ###################### + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + Qwen2Model.forward = qwen2Model_forward + + except ImportError: + logger.warning("Could not patch llama modelr - module not found") diff --git a/ucm/sparse/base.py b/ucm/sparse/base.py index ed62ab30c..7dc27fb46 100644 --- a/ucm/sparse/base.py +++ b/ucm/sparse/base.py @@ -23,7 +23,7 @@ import enum from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -117,11 +117,11 @@ def execute_begin(self, scheduler_output: SchedulerOutput): """ pass - def execute_finished(self): + def execute_finished(self, logits_indices: torch.Tensor) -> torch.Tensor: """ This is called at the end of "ModelRunner->execute_model" function. """ - pass + return logits_indices def attention_begin( self, @@ -130,14 +130,15 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, phase: Optional[str] = None, - ) -> None: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ This is called at the beginning of "unified_attention". Sparse attention algorithm can modify forward_context.attn_metadata if necessary. (UC_TODO: modify dataclass is not allowed in python?) """ - pass + return query, key, value, output def attention_finished( self, @@ -154,6 +155,44 @@ def attention_finished( """ pass + def ffn_begin( + self, hidden_states: torch.Tensor, residual: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This is called at the beginning of ffn in each DecodeLayer. + """ + return hidden_states, residual + + def ffn_finished( + self, hidden_states: torch.Tensor, residual: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This is called at the end of ffn in each DecodeLayer. 
+ """ + return hidden_states, residual + + def layer_begin( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This is called at the beginning of DecodeLayer. + """ + return positions, hidden_states, residual + + def layer_finished( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This is called at the end of DecodeLayer. + """ + return positions, hidden_states, residual + def request_finished_in_worker(self, request_id: Union[int, str]): """ This function releases the resources of finished requests at worker-side. diff --git a/ucm/sparse/blend/__init__.py b/ucm/sparse/blend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/sparse/blend/blend.py b/ucm/sparse/blend/blend.py new file mode 100644 index 000000000..7cc945674 --- /dev/null +++ b/ucm/sparse/blend/blend.py @@ -0,0 +1,362 @@ +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +from sympy import false +from torch import Tensor + +from ucm.logger import init_logger +from ucm.store.dramstore.dramstore_connector import device + +logger = init_logger(__name__) + +from vllm.config import VllmConfig +from vllm.forward_context import ForwardContext +from vllm.v1.request import Request + +from ucm.integration.vllm.blend_connector import BlendRequestDispatchMeta +from ucm.sparse.base import ( + INVALID_SLOT, + UcmSparseBase, + UcmSparseMetadata, + UcmSparseRole, +) +from ucm.sparse.utils import round_up + + +def get_num_blks(num_tokens, block_size): + return (num_tokens + block_size - 1) // block_size + + +@dataclass +class ReqMeta: + req_idx: int = 0 + need_blend: bool = false + + prefix_len: int = 0 + prefix_blk_len: int = 0 + + chunks_len: int = 0 + chunks_blk_len: int = 0 + + suffix_len: int = 0 + suffix_blk_len: int = 0 + + chunk_hit_mask: List[bool] = field(default_factory=list) + + chunk_hit_blk_len: int = 0 + + +@dataclass +class BlendMetaData(UcmSparseMetadata): + requests: list[ReqMeta] = field(default_factory=list) + compute_mask: Tensor = None + chunk_blks_hit_mask: Tensor = None + query_lens: Tensor = None + blend_start_req_idx: int = 0 + need_re_index: bool = False + + def reset_blend_meta(self, forward_mask, attn_metadata, scheduler_output): + # current not support chunk prefill + # for multi req in one batch, we should discard the decode req + self.need_re_index = False + self.requests = [] + self.compute_mask = forward_mask[: attn_metadata.query_start_loc[-1]] + + self.query_lens = ( + attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1] + ) + + self.blend_start_req_idx = len(scheduler_output.scheduled_cached_reqs.req_ids) + + def add_request( + self, + idx: int, + req_dispatch_meta: BlendRequestDispatchMeta, + seq_lens: Tensor, + block_size: int, + ) -> None: + chunks_meta = req_dispatch_meta.chunks_meta + if chunks_meta: + hit_mask = [] + req_idx_batch = self.blend_start_req_idx + idx + for meta in chunks_meta: + hit_mask.extend(meta.store_hits) + reqMeta = ReqMeta( + req_idx=req_idx_batch, + prefix_len=chunks_meta[0].start_token_dix, + prefix_blk_len=get_num_blks(chunks_meta[0].start_token_dix, block_size), + chunks_len=len(hit_mask) * block_size, + chunks_blk_len=len(hit_mask), + chunk_hit_mask=hit_mask, + chunk_hit_blk_len=sum(hit_mask), + ) + reqMeta.need_blend = 
reqMeta.chunk_hit_blk_len > 0 + reqMeta.suffix_len = ( + seq_lens[req_idx_batch].item() - reqMeta.prefix_len - reqMeta.chunks_len + ) + reqMeta.suffix_blk_len = get_num_blks(reqMeta.suffix_len, block_size) + + self.requests.append(reqMeta) + + def reset_compute_mask(self) -> None: + self.compute_mask.fill_(False) + # for decode req in the front of the batch + self.compute_mask[: self.blend_start_req_idx] = True + + def update_query_lens(self, req_idx: int, reused_num_tokens: int) -> None: + self.query_lens[req_idx] -= reused_num_tokens + + def update_need_re_index(self, need_re_index: bool) -> None: + self.need_re_index = need_re_index + + def update_req_compute_mask( + self, + req_query_start, + req_chunk_end, + req_query_end, + chunk_hit_mask, + top_k_indices, + ): + # for multi req batch, maybe we should update compute_mask in batch level rather than in req level + chunks = self.compute_mask[req_query_start:req_chunk_end] + chunks = chunks.reshape(len(chunk_hit_mask), -1) + + # for chunk block cache miss part, just recompute + chunks.masked_fill_(~chunk_hit_mask.unsqueeze(1), True) + + flat = chunks.view(-1) + # for chunk block cache hit part, just recompute HKVD(highest KV deviation) tokens + flat[top_k_indices] = True + + # for question part, default + self.compute_mask[req_chunk_end:req_query_end].fill_(True) + + +class Blend(UcmSparseBase): + def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): + super().__init__(vllm_config, role) + self.blend_config = vllm_config.kv_transfer_config.kv_connector_extra_config[ + "ucm_sparse_config" + ]["Blend"] + + max_model_len = vllm_config.model_config.max_model_len + self.block_size = vllm_config.cache_config.block_size + + self.device = vllm_config.device_config.device + self.forward_mask = torch.zeros(max_model_len, device=self.device).bool() + self.mask_idx = torch.arange( + round_up(max_model_len, self.block_size), device=self.device + ) + self.mask_idx = self.mask_idx.reshape(-1, self.block_size) + + # for multi batch, ignore the decode-stage req at the beginning + self.blend_start_req_idx = 0 + + self.compute_meta = self.blend_config["compute_meta"] + self.blend_req_metas: BlendMetaData = BlendMetaData( + need_re_index=False, + chunk_blks_hit_mask=torch.zeros( + round_up(max_model_len, self.block_size), device=self.device + ).bool(), + ) + self.attn_metadata = None + + def build_sparse_meta( + self, scheduler_output, requests, input_batch, attn_metadata + ) -> UcmSparseMetadata: + + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + self.attn_metadata = attn_metadata + + self.blend_req_metas.reset_blend_meta( + self.forward_mask, attn_metadata, scheduler_output + ) + + blend_conn_request_meta = scheduler_output.kv_connector_metadata.request_meta + for idx, request in enumerate(scheduler_output.scheduled_new_reqs): + req_id = request.req_id + self.blend_req_metas.add_request( + idx, + blend_conn_request_meta[req_id], + attn_metadata.seq_lens, + self.block_size, + ) + + return self.blend_req_metas + + def _update_attn_metadata(self): + # update attn_metadata, cause we sparse the prefill tokens + self.attn_metadata.slot_mapping = self.attn_metadata.slot_mapping[ + self.blend_req_metas.compute_mask + ] + self.attn_metadata.query_start_loc[1:] = torch.cumsum( + self.blend_req_metas.query_lens, dim=0 + ) + self.attn_metadata.max_query_len = self.blend_req_metas.query_lens.max().item() + self.attn_metadata.num_actual_tokens = ( + self.blend_req_metas.query_lens.sum().item() + ) + + def 
estimate_num_slots_sparsed(self, request: Request) -> int: + """ + This is called by "Scheduler->schedule" function to estimate the number of required blocks. + """ + return INVALID_SLOT + + def request_begin(self, request_id: Union[int, str], prompt_token_ids: List[int]): + pass + + def attention_begin( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, + phase: Optional[str] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + attn = forward_context.no_compile_layers[layer_name] + kv_cache = attn.kv_cache[forward_context.virtual_engine] + start_time = time.perf_counter() + if layer_name in self.compute_meta.keys(): + need_update = False + self.blend_req_metas.reset_compute_mask() + + # maybe we can use triton kernel + for req_meta in self.blend_req_metas.requests: + req_idx = req_meta.req_idx + req_query_start = self.attn_metadata.query_start_loc[req_idx].item() + req_query_end = self.attn_metadata.query_start_loc[req_idx + 1].item() + + if not req_meta.need_blend: + self.blend_req_metas.compute_mask[ + req_query_start:req_query_end + ].fill_(True) + continue + req_chunk_end = req_query_start + req_meta.chunks_len + + # HBM prefix cache is not supported now + # UC store prefix cache can be fully reused for the first chunk + his_vllm_blk_ids = self.attn_metadata.block_table[req_idx][ + req_meta.prefix_blk_len : req_meta.prefix_blk_len + + req_meta.chunks_blk_len + ] + # only compute topk of chunk's hits block + chunk_hit_mask = self.blend_req_metas.chunk_blks_hit_mask[ + : len(req_meta.chunk_hit_mask) + ] + src = torch.as_tensor( + req_meta.chunk_hit_mask, + dtype=chunk_hit_mask.dtype, + device=chunk_hit_mask.device, + ) + chunk_hit_mask.copy_(src) + + his_vllm_blk_ids = his_vllm_blk_ids[chunk_hit_mask] + his_k = kv_cache[0, his_vllm_blk_ids] + candidate_len = req_meta.chunk_hit_blk_len * self.block_size + his_k = his_k.reshape(candidate_len, -1) + + req_key = key[req_query_start:req_chunk_end] + + # req_key does not contain prefix cache + golden_k = req_key.reshape( + req_meta.chunks_blk_len, self.block_size, -1 + )[chunk_hit_mask] + golden_k = golden_k.reshape(candidate_len, -1) + + diff_k = torch.sum((his_k - golden_k).abs(), dim=[1]) + topK_num = int(candidate_len * self.compute_meta[layer_name]["ratio"]) + + topK_indices = torch.topk(diff_k, k=topK_num).indices + + # get origin idx in req_key + topK_indices = self.mask_idx[: req_meta.chunks_blk_len][ + chunk_hit_mask + ].reshape(-1)[topK_indices] + + # update compute_mask + self.blend_req_metas.update_req_compute_mask( + req_query_start, + req_chunk_end, + req_query_end, + chunk_hit_mask, + topK_indices, + ) + + self.blend_req_metas.update_query_lens( + req_idx, candidate_len - topK_num + ) + need_update = True + + if need_update: + logger.info( + f"[blend-attn] compute_mask time: {(time.perf_counter() - start_time) * 1000}ms" + ) + self.blend_req_metas.update_need_re_index(True) + self._update_attn_metadata() + + indexed_query = query[self.blend_req_metas.compute_mask] + indexed_key = key[self.blend_req_metas.compute_mask] + indexed_value = value[self.blend_req_metas.compute_mask] + indexed_output = None + if output is not None: + indexed_output = output[: self.blend_req_metas.compute_mask.sum()] + logger.info( + f"[blend-attn] compute_mask time + index time: {(time.perf_counter() - start_time) * 1000}ms" + ) + logger.info( + f"[blend-attn] reduce attn tokens from 
{len(self.blend_req_metas.compute_mask)} " + f"to {self.attn_metadata.num_actual_tokens}" + ) + return indexed_query, indexed_key, indexed_value, indexed_output + return query, key, value, output + + def ffn_begin( + self, hidden_states: torch.Tensor, residual: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states is equal to attn out, which is contiguous + if self.blend_req_metas.need_re_index and len( + self.blend_req_metas.compute_mask + ) == len(residual): + logger.info( + f"[blend-ffn] after cache blend, reduce ffn tokens from {len(self.blend_req_metas.compute_mask)} " + f"to {self.blend_req_metas.compute_mask.sum().item()}" + ) + return hidden_states[ + : self.attn_metadata.num_actual_tokens + ], self._index_tensor(residual) + return hidden_states, residual + + def layer_begin( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if len(positions) != len(hidden_states): + logger.info( + f"[blend-layer] after cache blend, reduce layer tokens from {len(self.blend_req_metas.compute_mask)} " + f"to {self.blend_req_metas.compute_mask.sum().item()}" + ) + return self._index_tensor(positions), hidden_states, residual + return positions, hidden_states, residual + + def execute_finished(self, logits_indices: torch.Tensor): + if self.blend_req_metas.need_re_index: + modified_logits_indices = self.attn_metadata.query_start_loc[1:] - 1 + logger.info( + f"[blend-model] modify logits_indices from {logits_indices} " + f"to {modified_logits_indices}" + ) + return modified_logits_indices + return logits_indices + + def _index_tensor(self, tensor: torch.Tensor): + if self.blend_req_metas.need_re_index: + return tensor[self.blend_req_metas.compute_mask] + return tensor diff --git a/ucm/sparse/blend/blockwise_rope.py b/ucm/sparse/blend/blockwise_rope.py new file mode 100644 index 000000000..4fea30733 --- /dev/null +++ b/ucm/sparse/blend/blockwise_rope.py @@ -0,0 +1,224 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope_blockwise_kernel( + k_ptr, # (total_blocks, seq_len, n_kv_head, hd) + vllm_ids, # (bs,) block id for each batch + positions, # (bs,) delta angle for each batch + cos_sin_cache, # (1, seq_len, hd) + k_row_stride, + k_head_stride, + cos_sin_row_stride, + sl, + bs: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_hd: tl.constexpr, +): + """ + each program/batch process a single head for each token + programs matrix (batch_idx, seq_idx, head_idx) + """ + pid = tl.program_id(0) + + heads_per_seq = n_kh + tokens_per_batch = sl * n_kh + batch_idx = pid // tokens_per_batch + seq_head_idx = pid % tokens_per_batch + seq_idx = seq_head_idx // n_kh + head_idx = seq_head_idx % n_kh + + # block id & position + block_id = tl.load(vllm_ids + batch_idx) + pos_idx = tl.load(positions + batch_idx) + + # k offset + k_offset = block_id * k_row_stride + seq_idx * (n_kh * hd) + head_idx * hd + k_ptr = k_ptr + k_offset + + # fetch cos sin from cos_sin_cache + cos_base = pos_idx * cos_sin_row_stride + sin_base = cos_base + hd // 2 # sin just behind cos + + offs = tl.arange(0, pad_hd // 2) + mask = offs < hd // 2 + + cos_row = tl.load(cos_sin_cache + cos_base + offs, mask=mask, other=0) + sin_row = tl.load(cos_sin_cache + sin_base + offs, mask=mask, other=0) + + k_tile_1 = tl.load(k_ptr + offs, mask=mask, other=0).to(cos_row.dtype) + k_tile_2 = tl.load(k_ptr + offs + hd // 2, mask=mask, other=0).to(cos_row.dtype) + + new_k_tile_1 = k_tile_1 * 
cos_row - k_tile_2 * sin_row + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + + tl.store(k_ptr + offs, new_k_tile_1, mask=mask) + tl.store(k_ptr + offs + hd // 2, new_k_tile_2, mask=mask) + + +def block_wise_rope_forward( + k_cache: torch.Tensor, + vllm_ids: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, +) -> torch.Tensor: + """ + Args: + k_cache: torch.Tensor (total_blocks, seq_len, n_kv_heads, hd), vllm owned. + vllm_ids: torch.LongTensor (batch_size,), vllm block id + positions: torch.LongTensor (batch_size,), delta angle of each block for rope + cos_sin_cache: torch.Tensor (1, seq_len, hd),same as the tensor in rotary_emb + """ + total_blocks, seq_len, n_kv_head, head_dim = k_cache.shape + batch_size = vllm_ids.shape[0] + pad_hd = triton.next_power_of_2(head_dim) + + k_cache = k_cache.contiguous() + vllm_ids = vllm_ids.contiguous() + positions = positions.contiguous() + cos_sin_cache = cos_sin_cache.contiguous() + + n_row = batch_size * seq_len * n_kv_head + + _triton_rope_blockwise_kernel[(n_row,)]( + k_cache, + vllm_ids, + positions, + cos_sin_cache, + k_cache.stride(0), + k_cache.stride(-2), + cos_sin_cache.stride(-2), + seq_len, + batch_size, + n_kv_head, + head_dim, + pad_hd, + ) + + return k_cache + + +def rope_naive_torch( + k_cache: torch.Tensor, + vllm_ids: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, +) -> torch.Tensor: + """ + naive torch implementation for accuracy and perf baseline + Args: + k_cache: (total_blocks, seq_len, n_heads, hd) + vllm_ids: (bs,) + positions: (bs,) + cos_sin_cache: (1, seq_len, hd) + Returns: + rotated_k: same shape as k_cache + """ + total_blocks, sl, nh, hd = k_cache.shape + bs = vllm_ids.shape[0] + + # copy to avoid in-place modifying original + k_out = k_cache.clone() + + half = hd // 2 + + # cos_sin_cache shape: (1, seq_len, hd) + cos_sin_cache = cos_sin_cache.squeeze(0) # (sl, hd) + cos_table = cos_sin_cache[:, :half] # (sl, half) + sin_table = cos_sin_cache[:, half:] # (sl, half) + + # Loop in python (slow but clear) + for b in range(bs): + blk = vllm_ids[b].item() + pos = positions[b].item() # rope offset + + for s in range(sl): + # cos, sin row for this position + cos = cos_table[pos] # (half,) + sin = sin_table[pos] + + for h in range(nh): + # read original k + k_vec = k_out[blk, s, h] # (hd,) + k1 = k_vec[:half] # (half,) + k2 = k_vec[half:] # (half,) + + # rope rotate + new_k1 = k1 * cos - k2 * sin + new_k2 = k2 * cos + k1 * sin + + # write back + k_out[blk, s, h, :half] = new_k1 + k_out[blk, s, h, half:] = new_k2 + + return k_out + + +if __name__ == "__main__": + import time + + torch.manual_seed(42) + + total_blocks = 5120 + num_blocks = 128 + block_size = 128 + max_num_tokens = num_blocks * block_size + num_heads = 8 + head_size = 128 + dtype = torch.bfloat16 + + kcache = torch.randn( + total_blocks, block_size, num_heads, head_size, device="cuda", dtype=dtype + ) + vllm_ids = torch.randint( + 0, total_blocks, (num_blocks,), device="cuda", dtype=torch.long + ) + positions = torch.randint( + 0, max_num_tokens, (num_blocks,), device="cuda", dtype=torch.long + ) + cos_sin_cache = torch.randn(max_num_tokens, head_size, device="cuda", dtype=dtype) + + # naive torch result + baseline_rope_kcache = rope_naive_torch(kcache, vllm_ids, positions, cos_sin_cache) + + triton_rope_kcache = block_wise_rope_forward( + kcache, vllm_ids, positions, cos_sin_cache + ) + + # precision compare + diff = (triton_rope_kcache[vllm_ids] - baseline_rope_kcache[vllm_ids]).abs() + mean_err = 
diff.mean().item() + print(f"MAE : {mean_err:.6f}. Expected 1e-3") + + def bench(fn, n_iter=50): + torch.cuda.synchronize() + t0 = time.time() + for _ in range(n_iter): + fn() + torch.cuda.synchronize() + dt = (time.time() - t0) / n_iter + return dt * 1e3 # ms + + ms = bench( + lambda: block_wise_rope_forward(kcache, vllm_ids, positions, cos_sin_cache) + ) + print(f"Kernel avg latency: {ms:.3f} ms. Expected 100 us") + + # load K,load cos,sin -> dump K + bytes_total = ( + num_blocks + * block_size + * num_heads + * ( + head_size * kcache.dtype.itemsize # K load + + vllm_ids.dtype.itemsize # vllm_ids load + + positions.dtype.itemsize # positions load + + head_size * cos_sin_cache.dtype.itemsize # cos sin load + + head_size * kcache.dtype.itemsize # K dump + ) + ) + bw = bytes_total / (ms / 1e3) / (1024**3) + print(f"Estimated memory BW: {bw:.1f} GiB/s") diff --git a/ucm/sparse/cache_blend/README.md b/ucm/sparse/cache_blend/README.md deleted file mode 100644 index 8b1378917..000000000 --- a/ucm/sparse/cache_blend/README.md +++ /dev/null @@ -1 +0,0 @@ - diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index ac36e54c5..97cf921be 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cache -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -554,8 +554,9 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, phase: Optional[str] = None, - ) -> None: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if not self.is_mla: for req_meta in self._sparse_metadata.requests: self.create_req_state_attention_begin( @@ -573,6 +574,8 @@ def attention_begin( req_meta, layer_name, query, key, value, forward_context ) + return query, key, value, output + def update_req_state_attention_end( self, req_meta, layer_name, query, key, value, attn_output, forward_context ): diff --git a/ucm/sparse/factory.py b/ucm/sparse/factory.py index d5b49cf37..da0f6b6e4 100644 --- a/ucm/sparse/factory.py +++ b/ucm/sparse/factory.py @@ -51,3 +51,4 @@ def create_sparse_method( UcmSparseFactory.register_sparse_method( "KVStarMultiStep", "ucm.sparse.kvstar.multistep", "KVStarMultiStep" ) +UcmSparseFactory.register_sparse_method("Blend", "ucm.sparse.blend.blend", "Blend") diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py index b1bf1e5c1..20aa7a69b 100644 --- a/ucm/sparse/gsa/gsa.py +++ b/ucm/sparse/gsa/gsa.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from functools import cache, wraps from itertools import accumulate -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from vllm.config import VllmConfig @@ -631,8 +631,9 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, phase: Optional[str] = None, - ) -> None: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: current_layer_id = int(layer_name.split(".")[2]) if self.prefetch_engine.atb_gsa_enable and self.prefetch_engine.is_topk_cal: if not self.use_mla: @@ -683,6 +684,8 @@ def attention_begin( ] ) + return query, key, value, output + def attention_finished( self, query: torch.Tensor, @@ -901,7 +904,7 @@ def execute_begin(self, scheduler_output: SchedulerOutput): self.gsa_stats = 
self.gsa_metadata.gsa_stats self._start_topk_cal() - def execute_finished(self): + def execute_finished(self, logits_indices: torch.Tensor): kv_caches = [None] * self.layer_num forward_context = get_forward_context() attn = forward_context.no_compile_layers @@ -930,6 +933,7 @@ def execute_finished(self): self.prefetch_engine.deal_async_prefetch( False, self.gsa_metadata, kv_caches, None ) + return logits_indices def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches): if all_free_block_ids == None: diff --git a/ucm/sparse/kvstar/multistep.py b/ucm/sparse/kvstar/multistep.py index 18ed4cb87..7467d1936 100644 --- a/ucm/sparse/kvstar/multistep.py +++ b/ucm/sparse/kvstar/multistep.py @@ -1,7 +1,7 @@ import enum import math from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from vllm.config import VllmConfig @@ -731,8 +731,9 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + output: Optional[torch.Tensor] = None, phase: Optional[str] = None, - ) -> None: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ This is called at the beginning of "unified_attention". Sparse attention algorithm can modify forward_context.attn_metadata if necessary. @@ -746,6 +747,8 @@ def attention_begin( req_layerwise_state.update_meta(req_meta, forward_context) req_layerwise_state.attention_begin(query, key, value, forward_context) + return query, key, value, output + def attention_finished( self, query: torch.Tensor, diff --git a/ucm/sparse/state.py b/ucm/sparse/state.py index a0f77a53b..36cd93909 100644 --- a/ucm/sparse/state.py +++ b/ucm/sparse/state.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Optional +import torch + from ucm.logger import init_logger from ucm.sparse.base import UcmSparseBase, UcmSparseRole from ucm.sparse.factory import UcmSparseFactory @@ -72,3 +74,37 @@ def has_ucm_sparse() -> bool: """Check if UCM sparse agent is available.""" global _UCM_SPARSE_AGENT return _UCM_SPARSE_AGENT is not None + + +def maybe_execute_sparse_layer_begin( + positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor +): + if not has_ucm_sparse(): + return positions, hidden_states, residual + ucm_spare = get_ucm_sparse() + return ucm_spare.layer_begin(positions, hidden_states, residual) + + +def maybe_execute_sparse_layer_finished( + positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor +): + if not has_ucm_sparse(): + return positions, hidden_states, residual + ucm_spare = get_ucm_sparse() + return ucm_spare.layer_finished(positions, hidden_states, residual) + + +def maybe_execute_sparse_ffn_begin(hidden_states: torch.Tensor, residual: torch.Tensor): + if not has_ucm_sparse(): + return hidden_states, residual + ucm_spare = get_ucm_sparse() + return ucm_spare.ffn_begin(hidden_states, residual) + + +def maybe_execute_sparse_ffn_finished( + hidden_states: torch.Tensor, residual: torch.Tensor +): + if not has_ucm_sparse(): + return hidden_states, residual + ucm_spare = get_ucm_sparse() + return ucm_spare.ffn_finished(hidden_states, residual)
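
The `maybe_execute_sparse_*` helpers added to `ucm/sparse/state.py` above only guard on `has_ucm_sparse()`; the model-side call sites are not part of this change. The sketch below shows one plausible placement of those hooks inside a decoder layer so that Blend can shrink the token dimension after its compute mask is built. `SparseAwareDecoderLayer`, its constructor arguments, and the fused-residual norm/attention call conventions are illustrative assumptions (vLLM-style), not UCM or vLLM API.

```python
# Illustrative sketch only -- NOT part of the diff above.
from typing import Optional, Tuple

import torch

from ucm.sparse.state import (
    maybe_execute_sparse_ffn_begin,
    maybe_execute_sparse_ffn_finished,
    maybe_execute_sparse_layer_begin,
    maybe_execute_sparse_layer_finished,
)


class SparseAwareDecoderLayer(torch.nn.Module):  # hypothetical wrapper
    def __init__(self, attn, mlp, input_norm, post_attn_norm):
        super().__init__()
        self.self_attn = attn
        self.mlp = mlp
        # Assumed vLLM-style fused-residual RMSNorm: norm(x, residual) -> (normed, residual)
        self.input_layernorm = input_norm
        self.post_attention_layernorm = post_attn_norm

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Hook point: lets the sparse module re-index positions once earlier
        # layers have reduced the token dimension (no-op otherwise).
        positions, hidden_states, residual = maybe_execute_sparse_layer_begin(
            positions, hidden_states, residual
        )

        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Hook point: Blend.ffn_begin drops tokens whose KV was reused, so the FFN
        # only runs on recomputed (HKVD, cache-miss, and suffix) tokens.
        hidden_states, residual = maybe_execute_sparse_ffn_begin(hidden_states, residual)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states, residual = maybe_execute_sparse_ffn_finished(
            hidden_states, residual
        )

        positions, hidden_states, residual = maybe_execute_sparse_layer_finished(
            positions, hidden_states, residual
        )
        return hidden_states, residual
```

In this layout the first configured attention layer (e.g. `model.layers.1.self_attn.attn` in the example config) builds the compute mask inside `attention_begin`, `ffn_begin` then shrinks the hidden states for the rest of that layer, and every subsequent layer sees the reduced token set via `layer_begin`.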