diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 8b804408ea41f..a82f1bb20fba3 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -27,7 +27,7 @@
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches, slice_query_start_locs)
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
@@ -172,27 +172,28 @@ class FlashAttentionMetadataBuilder(
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
 
-    def build_slice(
-        self,
-        req_slice: slice,
-        token_slice: slice,
-        max_query_len: int,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
+    def build(
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
     ) -> FlashAttentionMetadata:
-        num_reqs = req_slice.stop - req_slice.start
-        num_tokens = token_slice.stop - token_slice.start
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(self.runner.seq_lens_np[req_slice].max())
-        query_start_loc = slice_query_start_locs(
-            common_attn_metadata.query_start_loc, req_slice)
-        seq_lens = common_attn_metadata.seq_lens[req_slice]
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[req_slice]
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
 
-        block_table.slot_mapping[token_slice].copy_(
-            block_table.slot_mapping_cpu[token_slice], non_blocking=True)
-        slot_mapping = block_table.slot_mapping[token_slice]
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
@@ -234,8 +235,8 @@ class FlashAttentionMetadataBuilder(
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
-                    query_start_loc,
-                    seq_lens,
+                    self.runner.query_start_loc_np[:num_reqs + 1],
+                    self.runner.seq_lens_np[:num_reqs],
                     block_table_tensor,
                     self.block_size,
                 )
@@ -265,20 +266,20 @@ class FlashAttentionMetadataBuilder(
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                               common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
-                max_query_len=num_tokens,
+                max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
@@ -302,7 +303,7 @@ class FlashAttentionMetadataBuilder(
                 causal=True)
 
         attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
+            num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
@@ -320,28 +321,13 @@ class FlashAttentionMetadataBuilder(
         )
         return attn_metadata
 
-    def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-        max_query_len = common_attn_metadata.max_query_len
-        return self.build_slice(
-            req_slice=slice(0, num_reqs),
-            token_slice=slice(0, num_actual_tokens),
-            max_query_len=max_query_len,
-            common_prefix_len=common_prefix_len,
-            common_attn_metadata=common_attn_metadata,
-        )
-
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         # Full CUDA Graph always supported (FA2 support checked separately)
         return True
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False  #use_cascade_attention(*args, **kwargs)
+        return use_cascade_attention(*args, **kwargs)
 
 
 class FlashAttentionImpl(AttentionImpl):
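
The -1 padding on block_table.slot_mapping matters because, with a full CUDA graph, the cache-write kernel is replayed over the captured (padded) token count rather than just num_actual_tokens. The loop below is only an illustrative sketch of why the sentinel works, under the assumption that the cache-write kernel skips any token whose slot index is negative; reshape_and_cache_ref and all shapes here are made up for the example and are not the real CUDA op.

    import torch

    def reshape_and_cache_ref(key: torch.Tensor, value: torch.Tensor,
                              key_cache: torch.Tensor,
                              value_cache: torch.Tensor,
                              slot_mapping: torch.Tensor) -> None:
        """Write each token's K/V into its flat cache slot, skipping slot < 0."""
        for token_idx, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:
                # Padded token: the captured graph is wider than this batch,
                # so there is nothing valid to write for this position.
                continue
            key_cache[slot] = key[token_idx]
            value_cache[slot] = value[token_idx]

    # A graph captured for 8 tokens, replayed with only 5 real tokens.
    num_padded, num_actual, head_dim = 8, 5, 4
    key = torch.randn(num_padded, head_dim)
    value = torch.randn(num_padded, head_dim)
    key_cache = torch.zeros(16, head_dim)
    value_cache = torch.zeros(16, head_dim)

    slot_mapping = torch.full((num_padded,), -1, dtype=torch.int64)
    slot_mapping[:num_actual] = torch.tensor([3, 4, 5, 9, 10])  # real slots
    reshape_and_cache_ref(key, value, key_cache, value_cache, slot_mapping)
    # Only slots 3, 4, 5, 9, 10 are written; the 3 padded tokens touch nothing.

Under that assumption, stale slot indices left over from an earlier, larger batch can never be written to during replay, which is what the fill_(-1) added in the hunk above guarantees.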