diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py
index beb9add10024b..2bf95943ee095 100644
--- a/vllm/attention/layers/static_sink_attention.py
+++ b/vllm/attention/layers/static_sink_attention.py
@@ -66,20 +66,13 @@ def create_static_sink_attention_backend(
             common_attn_metadata.seq_lens[:] = (
                 common_attn_metadata.seq_lens + self.sink_len
             )
+            common_attn_metadata.seq_lens[
+                common_attn_metadata.seq_lens == self.sink_len
+            ] = 0
             common_attn_metadata.max_seq_len = (
                 common_attn_metadata.max_seq_len + self.sink_len
             )
 
-            blk_table_tensor = common_attn_metadata.block_table_tensor
-            sink_block_table = self.sink_block_table[None, :].expand(
-                blk_table_tensor.shape[0], -1
-            )
-            blk_table_tensor_clone = blk_table_tensor.clone()
-            blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[
-                :, : -self.num_sink_blocks
-            ]
-            blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table
-
             return super().build(common_prefix_len, common_attn_metadata, fast_build)
 
         attn_backend = subclass_attention_backend(
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 14905b36754b4..fe9e7a9891941 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -801,6 +801,146 @@ class SinkFullAttentionManager(FullAttentionManager):
         num_sink_block = sink_len // self.block_size
         self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
 
+    def get_num_blocks_to_allocate(
+        self,
+        request_id: str,
+        num_tokens: int,
+        new_computed_blocks: Sequence[KVCacheBlock],
+    ) -> int:
+        """
+        Get the number of blocks needed to be allocated for the request.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix cache.
+
+        Returns:
+            The number of blocks.
+        """
+
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = (
+            num_required_blocks
+            - len(new_computed_blocks)
+            - len(self.req_to_blocks[request_id])
+        )
+        # req_to_blocks already includes the sink blocks, so add them back here.
+        if len(self.req_to_blocks[request_id]) > 0:
+            num_new_blocks = num_new_blocks + len(self.sink_blocks)
+        # If a computed block of a request is an eviction candidate (in the
+        # free queue and ref_cnt == 0), it will be changed from a free block
+        # to a computed block when the request is allocated, so we also count
+        # it as needed to be allocated.
+        num_evictable_computed_blocks = sum(
+            blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
+        )
+        return num_new_blocks + num_evictable_computed_blocks
+
+    def save_new_computed_blocks(
+        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
+    ) -> None:
+        """
+        Add the new computed blocks to the request.
+
+        Args:
+            request_id: The request ID.
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix cache.
+        """
+        if request_id not in self.num_cached_block:
+            # A new request.
+            req_blocks = self.req_to_blocks[request_id]
+            assert len(req_blocks) == 0
+            # Put the sink blocks first, then the blocks that hit the prefix cache.
+            req_blocks.extend(self.sink_blocks + new_computed_blocks)
+            self.num_cached_block[request_id] = len(new_computed_blocks)
+        else:
+            # A running request. Should not have new computed blocks.
+            assert len(new_computed_blocks) == 0
+
+    def allocate_new_blocks(
+        self, request_id: str, num_tokens: int
+    ) -> list[KVCacheBlock]:
+        """
+        Allocate new blocks for the request to give it at least `num_tokens`
+        token slots.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+
+        Returns:
+            The newly allocated blocks.
+        """
+        req_blocks = self.req_to_blocks[request_id]
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = num_required_blocks - len(req_blocks)
+        # For an existing request, req_blocks already includes the sink
+        # blocks, so add them back to num_new_blocks.
+        if len(req_blocks) > 0:
+            num_new_blocks = num_new_blocks + len(self.sink_blocks)
+        if num_new_blocks <= 0:
+            return []
+        else:
+            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
+            # For a new request, prepend the preallocated sink blocks.
+            if len(req_blocks) == 0:
+                req_blocks.extend(self.sink_blocks + new_blocks)
+            else:
+                req_blocks.extend(new_blocks)
+            return new_blocks
+
+    def cache_blocks(self, request: Request, num_tokens: int) -> None:
+        """
+        Cache the blocks for the request.
+
+        Args:
+            request: The request.
+            num_tokens: The total number of tokens that need to be cached
+                (including tokens that are already cached).
+        """
+        num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
+        num_full_blocks = num_tokens // self.block_size
+
+        if num_cached_blocks >= num_full_blocks:
+            return
+
+        self.block_pool.cache_full_blocks(
+            request=request,
+            # Do not cache the sink blocks.
+            blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :],
+            num_cached_blocks=num_cached_blocks,
+            num_full_blocks=num_full_blocks,
+            block_size=self.block_size,
+            kv_cache_group_id=self.kv_cache_group_id,
+        )
+
+        self.num_cached_block[request.request_id] = num_full_blocks
+
+    def free(self, request_id: str) -> None:
+        """
+        Free the blocks for the request.
+
+        Args:
+            request_id: The request ID.
+        """
+        # Default to [] in case a request is freed (aborted) before alloc.
+        req_blocks = self.req_to_blocks.pop(request_id, [])
+        # Do not free the sink blocks; they are shared across requests.
+        if len(req_blocks) > 0:
+            req_blocks = req_blocks[len(self.sink_blocks) :]
+
+        # Free blocks in reverse order so that the tail blocks are
+        # freed first.
+        ordered_blocks = reversed(req_blocks)
+
+        self.block_pool.free_blocks(ordered_blocks)
+        self.num_cached_block.pop(request_id, None)
+
 
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index dd61d2150a797..5703f80db0754 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -23,6 +23,7 @@ class BlockTable:
         device: torch.device,
         kernel_block_size: int,
         cp_kv_cache_interleave_size: int,
+        sink_len: int = 0,
     ):
         """
         Args:
@@ -63,6 +64,8 @@ class BlockTable:
             self.use_hybrid_blocks = True
 
         self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
+        self.sink_block_len = sink_len // self.block_size
+        self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len
 
         self.block_table = self._make_buffer(
             self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
@@ -151,7 +154,7 @@ class BlockTable:
             block_table_indices = (
                 req_indices * self.max_num_blocks_per_req
                 + positions // virtual_block_size
-            )
+            ) + self.sink_block_len
 
             block_numbers = self.block_table.np.ravel()[block_table_indices]
             # Use virtual_block_size for mask calculation, which marks local
@@ -177,9 +180,10 @@ class BlockTable:
                 mask, slot_mapping, -1
             )
         else:
+            # When self.sink_block_len > 0, we need to shift block table indices
             block_table_indices = (
                 req_indices * self.max_num_blocks_per_req
                 + positions // self.block_size
-            )
+            ) + self.sink_block_len
             block_numbers = self.block_table.np.ravel()[block_table_indices]
             block_offsets = positions % self.block_size
@@ -293,7 +297,7 @@ class MultiGroupBlockTable:
                 block_size,
                 max_num_reqs,
                 max(
-                    cdiv(max_model_len + sink_len, block_size * total_cp_world_size),
+                    cdiv(max_model_len, block_size * total_cp_world_size),
                     1 + num_speculative_tokens,
                 ),
                 max_num_batched_tokens,
@@ -301,6 +305,7 @@ class MultiGroupBlockTable:
                 device,
                 kernel_block_size,
                 cp_kv_cache_interleave_size,
+                sink_len=sink_len,
             )
             for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
         ]
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 09a5bd885309d..6845652688430 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -101,7 +101,10 @@ def _reshape_kv_cache(
         num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
 
         attn_backend = attn_backends[layer_name]
-        if hasattr(kv_cache_spec, "head_size_v"):
+        if (
+            getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
+            != kv_cache_spec.head_size
+        ):
             kwargs = {"head_size_v": kv_cache_spec.head_size_v}
             stride_kwargs = {"diff_kv": True}
         else:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5f81f9ba23ddc..0bf773c84596d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5206,7 +5206,10 @@ class GPUModelRunner(
             )
             kernel_num_blocks = num_blocks * num_blocks_per_kv_block
 
-            if hasattr(kv_cache_spec, "head_size_v"):
+            if (
+                getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
+                != kv_cache_spec.head_size
+            ):
                 kwargs = {"head_size_v": kv_cache_spec.head_size_v}
                 stride_kwargs = {"diff_kv": True}
             else:
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 70e99db9e9762..a31013502c377 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -190,7 +190,10 @@ class KVConnectorModelRunnerMixin:
             return False
 
         attn_backend = attn_group.backend
-        if hasattr(kv_cache_spec, "head_size_v"):
+        if (
+            getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
+            != kv_cache_spec.head_size
+        ):
             kwargs = {"head_size_v": kv_cache_spec.head_size_v}
             stride_kwargs = {"diff_kv": True}
         else:
@@ -265,7 +268,10 @@ class KVConnectorModelRunnerMixin:
         kernel_num_blocks = num_blocks * num_blocks_per_kv_block
 
         attn_backend = attn_group.backend
-        if hasattr(kv_cache_spec, "head_size_v"):
+        if (
+            getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
+            != kv_cache_spec.head_size
+        ):
             kwargs = {"head_size_v": kv_cache_spec.head_size_v}
             stride_kwargs = {"diff_kv": True}
         else:
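
For orientation, the sketch below is a minimal, self-contained model of the sink-block bookkeeping this diff introduces: a fixed run of sink blocks is carved out of the free pool once, logically prepended to every request's block list (which is why the lookups in block_table.py are shifted by sink_block_len), and excluded from both prefix caching and freeing. The names here (SinkBlockBookkeeping, allocate, free) are illustrative only and are not part of vLLM's API; the real SinkFullAttentionManager operates on KVCacheBlock objects and a BlockPool rather than plain integers.

from collections import defaultdict
from math import ceil


class SinkBlockBookkeeping:
    """Toy model of the sink-block accounting used by SinkFullAttentionManager."""

    def __init__(self, block_size: int, sink_len: int, num_blocks: int) -> None:
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        # Carve the sink blocks out of the free pool once; they are shared by
        # all requests and never returned (mirrors popleft_n in __init__).
        num_sink_blocks = sink_len // block_size
        self.sink_blocks = [self.free_blocks.pop(0) for _ in range(num_sink_blocks)]
        self.req_to_blocks: defaultdict[str, list[int]] = defaultdict(list)

    def allocate(self, request_id: str, num_tokens: int) -> list[int]:
        """Give the request enough blocks for num_tokens token slots."""
        req_blocks = self.req_to_blocks[request_id]
        num_required = ceil(num_tokens / self.block_size)
        num_new = num_required - len(req_blocks)
        if req_blocks:
            # req_blocks already contains the sink blocks; add them back so
            # the request still ends up with num_required blocks of its own.
            num_new += len(self.sink_blocks)
        if num_new <= 0:
            return []
        new_blocks = [self.free_blocks.pop(0) for _ in range(num_new)]
        if not req_blocks:
            # New request: the sink blocks come first, then its own blocks,
            # matching the shifted block table lookups in block_table.py.
            req_blocks.extend(self.sink_blocks)
        req_blocks.extend(new_blocks)
        return new_blocks

    def free(self, request_id: str) -> None:
        """Return the request's own blocks; keep the shared sink blocks."""
        req_blocks = self.req_to_blocks.pop(request_id, [])
        own_blocks = req_blocks[len(self.sink_blocks) :]
        # Free in reverse order, mirroring ordered_blocks = reversed(...) above.
        self.free_blocks.extend(reversed(own_blocks))


if __name__ == "__main__":
    mgr = SinkBlockBookkeeping(block_size=16, sink_len=32, num_blocks=64)
    mgr.allocate("req-0", num_tokens=40)
    print(mgr.req_to_blocks["req-0"])   # [0, 1, 2, 3, 4]: 2 sink blocks + 3 own
    mgr.free("req-0")
    print(mgr.sink_blocks, len(mgr.free_blocks))  # [0, 1] 62

In the real code, the builder in static_sink_attention.py additionally bumps seq_lens and max_seq_len by sink_len so the attention kernel also covers the sink slots that sit ahead of each request's own blocks.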