From cf58a620990c3bd1ae45cccc90575b77351f3637 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Mon, 22 Dec 2025 22:48:14 +0800 Subject: [PATCH] Refactor code, move static sink logics to builder Signed-off-by: yuantao <2422264527@qq.com> --- .../attention/layers/static_sink_attention.py | 47 +++++- vllm/v1/core/single_type_kv_cache_manager.py | 141 ------------------ vllm/v1/worker/block_table.py | 10 +- vllm/v1/worker/gpu/attn_utils.py | 14 +- vllm/v1/worker/gpu_input_batch.py | 2 - vllm/v1/worker/gpu_model_runner.py | 21 +-- .../worker/kv_connector_model_runner_mixin.py | 22 --- 7 files changed, 48 insertions(+), 209 deletions(-) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py index 2bf95943ee095..e5ed16ec14932 100644 --- a/vllm/attention/layers/static_sink_attention.py +++ b/vllm/attention/layers/static_sink_attention.py @@ -17,6 +17,8 @@ from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -48,9 +50,23 @@ def create_static_sink_attention_backend( device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) + model_config = vllm_config.model_config + scheduler_config = vllm_config.scheduler_config self.sink_len = sink_len + self.block_size = vllm_config.cache_config.block_size self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size - self.sink_block_table = torch.arange( + self.max_num_blocks = cdiv( + model_config.max_model_len, vllm_config.cache_config.block_size + ) + self.block_table_with_sink = torch.zeros( + ( + scheduler_config.max_num_seqs, + self.max_num_blocks + self.num_sink_blocks, + ), + device=device, + dtype=torch.int32, + ) + self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange( 1, self.num_sink_blocks + 1, device=device, @@ -72,6 +88,14 @@ def create_static_sink_attention_backend( common_attn_metadata.max_seq_len = ( common_attn_metadata.max_seq_len + self.sink_len ) + max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size) + num_reqs = common_attn_metadata.num_reqs + self.block_table_with_sink[ + :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks + ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks] + common_attn_metadata.block_table_tensor = self.block_table_with_sink[ + :num_reqs + ] return super().build(common_prefix_len, common_attn_metadata, fast_build) @@ -84,7 +108,8 @@ def create_static_sink_attention_backend( return attn_backend -class StaticSinkAttention(Attention): +@CustomOp.register("static_sink_attention") +class StaticSinkAttention(Attention, CustomOp): """ Attention with static sink tokens """ @@ -118,7 +143,8 @@ class StaticSinkAttention(Attention): underlying_attn_backend, sink_len=sink_len, ) - super().__init__( + Attention.__init__( + self=self, num_heads=num_heads, head_size=head_size, scale=scale, @@ -126,6 +152,7 @@ class StaticSinkAttention(Attention): attn_backend=attn_backend, **kwargs, ) + CustomOp.__init__(self) self.sink_len = sink_len self.block_size = block_size @@ -137,7 +164,7 @@ class StaticSinkAttention(Attention): self.sink_key = sink_key self.sink_value = sink_value - 
def forward( + def forward_native( self, query: torch.Tensor, key: torch.Tensor, @@ -154,6 +181,18 @@ class StaticSinkAttention(Attention): return super().forward(query, key, value, output_shape) + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + return self.forward_native(query, key, value, output_shape) + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + def populate_sink_kv(self, self_kv_cache): sink_kv_slot_mapping = torch.arange( self.block_size, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e6c7150500cc3..14905b36754b4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -801,147 +801,6 @@ class SinkFullAttentionManager(FullAttentionManager): num_sink_block = sink_len // self.block_size self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) - def get_num_blocks_to_allocate( - self, - request_id: str, - num_tokens: int, - new_computed_blocks: Sequence[KVCacheBlock], - ) -> int: - """ - Get the number of blocks needed to be allocated for the request. - - Args: - request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including - tokens that are already allocated). - new_computed_blocks: The new computed blocks just hitting the - prefix caching. - - Returns: - The number of blocks. - """ - - num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = ( - num_required_blocks - - len(new_computed_blocks) - - len(self.req_to_blocks[request_id]) - ) - # Number of sink blocks is calculated into num_new_blocks - if len(self.req_to_blocks[request_id]) > 0: - num_new_blocks = num_new_blocks + len(self.sink_blocks) - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it will be changed from a free block - # to a computed block when the request is allocated, so we also count - # it as needed to be allocated. - num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks - ) - return num_new_blocks + num_evictable_computed_blocks - - def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] - ) -> None: - """ - Add the new computed blocks to the request. - - Args: - request_id: The request ID. - new_computed_blocks: The new computed blocks just hitting the - prefix cache. - """ - if request_id not in self.num_cached_block: - # A new request. - req_blocks = self.req_to_blocks[request_id] - assert len(req_blocks) == 0 - # Append both sink blocks and hitted prefix cache blocks - req_blocks.extend(self.sink_blocks) - req_blocks.extend(new_computed_blocks) - self.num_cached_block[request_id] = len(new_computed_blocks) - else: - # A running request. Should not have new computed blocks. - assert len(new_computed_blocks) == 0 - - def allocate_new_blocks( - self, request_id: str, num_tokens: int - ) -> list[KVCacheBlock]: - """ - Allocate new blocks for the request to give it at least `num_tokens` - token slots. - - Args: - request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including - tokens that are already allocated). - - Returns: - The new allocated blocks. 
- """ - req_blocks = self.req_to_blocks[request_id] - num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = num_required_blocks - len(req_blocks) - # For existing requests, number of sink blocks is calculated into - # num_new_blocks - if len(req_blocks) > 0: - num_new_blocks = num_new_blocks + len(self.sink_blocks) - if num_new_blocks <= 0: - return [] - else: - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - # For new requests, allocate sink blocks - if len(req_blocks) == 0: - req_blocks.extend(self.sink_blocks + new_blocks) - else: - req_blocks.extend(new_blocks) - return new_blocks - - def cache_blocks(self, request: Request, num_tokens: int) -> None: - """ - Cache the blocks for the request. - - Args: - request: The request. - num_tokens: The total number of tokens that need to be cached - (including tokens that are already cached). - """ - num_cached_blocks = self.num_cached_block.get(request.request_id, 0) - num_full_blocks = num_tokens // self.block_size - - if num_cached_blocks >= num_full_blocks: - return - - self.block_pool.cache_full_blocks( - request=request, - # Do not cache sink blocks - blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks, - block_size=self.block_size, - kv_cache_group_id=self.kv_cache_group_id, - ) - - self.num_cached_block[request.request_id] = num_full_blocks - - def free(self, request_id: str) -> None: - """ - Free the blocks for the request. - - Args: - request_id: The request ID. - """ - # Default to [] in case a request is freed (aborted) before alloc. - req_blocks = self.req_to_blocks.pop(request_id, []) - # Do not free sink blocks - if len(req_blocks) > 0: - req_blocks = req_blocks[len(self.sink_blocks) :] - - # Free blocks in reverse order so that the tail blocks are - # freed first. 
- ordered_blocks = reversed(req_blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request_id, None) - spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5703f80db0754..37ec0fb97e06b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -23,7 +23,6 @@ class BlockTable: device: torch.device, kernel_block_size: int, cp_kv_cache_interleave_size: int, - sink_len: int = 0, ): """ Args: @@ -64,8 +63,6 @@ class BlockTable: self.use_hybrid_blocks = True self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block - self.sink_block_len = sink_len // self.block_size - self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len self.block_table = self._make_buffer( self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 @@ -154,7 +151,7 @@ class BlockTable: block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size - ) + self.sink_block_len + ) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local @@ -180,10 +177,9 @@ class BlockTable: mask, slot_mapping, -1 ) else: - # When self.sink_block_len > 0, we need to shift block table indices block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // self.block_size - ) + self.sink_block_len + ) block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size @@ -267,7 +263,6 @@ class MultiGroupBlockTable: kernel_block_sizes: list[int], num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, - sink_len: int = 0, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -305,7 +300,6 @@ class MultiGroupBlockTable: device, kernel_block_size, cp_kv_cache_interleave_size, - sink_len=sink_len, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 604edf3d7ae29..6386f1a08b446 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -101,28 +101,16 @@ def _reshape_kv_cache( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - **kwargs, ) # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. 
try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - **stride_kwargs - ) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c567fc7219c3b..ead7a3619dea5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -96,7 +96,6 @@ class InputBatch: is_pooling_model: bool = False, num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, - sink_len: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -151,7 +150,6 @@ class InputBatch: kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, - sink_len=sink_len, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e02e7ee2d4c26..b1e4d04717768 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -332,10 +332,6 @@ class GPUModelRunner( self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.inputs_embeds_size = model_config.get_inputs_embeds_size() self.attention_chunk_size = model_config.attention_chunk_size - self.sink_len = getattr( - self.vllm_config.model_config.hf_config, "param_sink_number", 0 - ) - assert self.sink_len % self.cache_config.block_size == 0 # Only relevant for models using ALiBi (e.g, MPT) self.use_alibi = model_config.uses_alibi @@ -459,7 +455,6 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, - sink_len=self.sink_len, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -5079,7 +5074,7 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, - sink_len=self.sink_len, + # sink_len=self.sink_len, ) def _allocate_kv_cache_tensors( @@ -5206,28 +5201,16 @@ class GPUModelRunner( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, - **kwargs, ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - **stride_kwargs - ) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a45b83dc788f1..f266a1386e10d 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -190,28 +190,17 @@ class KVConnectorModelRunnerMixin: return False 
attn_backend = attn_group.backend - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( 1234, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, - **kwargs, ) try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=True, - **stride_kwargs, ) except (AttributeError, NotImplementedError): return False @@ -268,22 +257,12 @@ class KVConnectorModelRunnerMixin: kernel_num_blocks = num_blocks * num_blocks_per_kv_block attn_backend = attn_group.backend - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, - **kwargs, ) # prepend a num_layers dimension into the shape @@ -292,7 +271,6 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=True, - **stride_kwargs, ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError):
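
A few illustrative sketches follow; they are not part of the patch. The heart of the refactor is that sink blocks are no longer prepended to each request's block table by the KV-cache manager, nor offset-compensated in the worker's BlockTable. Instead, the static-sink metadata builder keeps a widened block table whose leading columns always point at the sink blocks, and build() copies each request's ordinary block table in behind them before delegating to the wrapped builder. A minimal sketch of that splice, using small made-up sizes (the tensor names mirror the patch, the concrete numbers are illustrative only):

import torch

block_size = 16
sink_len = 32
num_sink_blocks = sink_len // block_size   # sink blocks placed ahead of every request
max_num_blocks = 4                         # stands in for cdiv(max_model_len, block_size)
max_num_seqs = 3

# Allocated once in __init__: sink block IDs 1..num_sink_blocks pinned up front.
block_table_with_sink = torch.zeros(
    (max_num_seqs, max_num_blocks + num_sink_blocks), dtype=torch.int32
)
block_table_with_sink[:, :num_sink_blocks] = torch.arange(1, num_sink_blocks + 1)

# Done on every build(): copy the scheduler's block table for the active
# requests behind the sink columns and hand the widened view to the backend.
num_reqs = 2
block_table_tensor = torch.tensor(
    [[10, 11, 0, 0],
     [20, 21, 22, 0]], dtype=torch.int32
)
max_blocks = 3   # cdiv(max_seq_len + sink_len, block_size), simplified here
block_table_with_sink[
    :num_reqs, num_sink_blocks : num_sink_blocks + max_blocks
] = block_table_tensor[:, :max_blocks]

print(block_table_with_sink[:num_reqs])
# tensor([[ 1,  2, 10, 11,  0,  0],
#         [ 1,  2, 20, 21, 22,  0]], dtype=torch.int32)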
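
StaticSinkAttention now registers itself as a CustomOp, splits its implementation into forward_native/forward_cuda, and routes forward() through self._forward_method. Below is a simplified, self-contained stand-in for that dispatch pattern (this is not vLLM's actual CustomOp; the Mini* class names are invented for illustration). It shows why the subclass overrides forward() explicitly and calls both parent __init__ methods when it also inherits from Attention, which defines its own forward():

import torch


class MiniCustomOp:
    def __init__(self):
        # vLLM resolves the implementation per platform; here we only
        # distinguish "CUDA available" from "fall back to the native path".
        self._forward_method = (
            self.forward_cuda if torch.cuda.is_available() else self.forward_native
        )

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


class MiniAttention:
    def __init__(self, scale: float):
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


class MiniStaticSinkAttention(MiniAttention, MiniCustomOp):
    def __init__(self, scale: float):
        MiniAttention.__init__(self, scale)   # explicit, mirroring Attention.__init__
        MiniCustomOp.__init__(self)           # sets self._forward_method

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return MiniAttention.forward(self, x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    # Without this override, MiniAttention.forward would win the MRO and the
    # dispatcher set up by MiniCustomOp.__init__ would never run.
    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


print(MiniStaticSinkAttention(0.5).forward(torch.ones(2)))  # tensor([0.5000, 0.5000])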
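
On the worker side, BlockTable drops the sink_block_len column padding and index shift, so slot mapping goes back to indexing the flattened block table directly. A small numeric walk-through of that arithmetic with assumed sizes; the final slot computation follows the usual block_number * block_size + offset form, which is not shown in the hunk above:

import numpy as np

block_size = 16
max_num_blocks_per_req = 4

# Flattened (num_reqs, max_num_blocks_per_req) block table for two requests.
block_table = np.array(
    [[10, 11, 0, 0],
     [20, 21, 22, 0]], dtype=np.int32
)

req_indices = np.array([0, 0, 1])   # owning request of each token
positions = np.array([5, 17, 33])   # token position within its request

# No sink_block_len offset any more: index the raveled table directly.
block_table_indices = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.ravel()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets

print(block_numbers)   # [10 11 22]
print(slot_mapping)    # [165 177 353]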