diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 44f95c7686863..53fafbc4af91d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -10,8 +10,7 @@ import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-from flashinfer.decode import (_get_range_buf, get_seq_lens,
-                               trtllm_batch_decode_with_kv_cache)
+from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 
 import vllm.envs as envs
@@ -142,19 +141,10 @@ class FlashInferMetadata:
     # The number of entries in the last page of each request in
     # the paged kv cache, shape: [batch_size] (CPU for plan)
     paged_kv_last_page_len_cpu: torch.Tensor
-    # The number of query/output heads
-    num_qo_heads: int
-    # The number of key/value heads
-    num_kv_heads: int
-    # The dimension of the attention heads
-    head_dim: int
-    # Block size of vllm
-    page_size: int
-    # The data type of the paged kv cache
-    kv_data_type: torch.dtype
     # The data type of the query
     q_data_type: torch.dtype
 
+    seq_lens_cpu: torch.Tensor
     slot_mapping: torch.Tensor
 
     # For flashinfer trtllm batch decode
@@ -185,10 +175,6 @@ class FlashInferMetadata:
     qo_indptr_gpu: Optional[torch.Tensor] = None
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None
 
-    def __post_init__(self):
-        if self.head_dim is not None:
-            FlashInferBackend.validate_head_size(self.head_dim)
-
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
@@ -201,13 +187,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.device = device
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
+        self.model_config = vllm_config.model_config
         self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
         self.compilation_config = vllm_config.compilation_config
-        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+        max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
@@ -221,6 +208,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self._decode_cudagraph_max_bs = min(
                 max_num_reqs, self.compilation_config.max_capture_size)
 
+        self.num_qo_heads = self.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config)
+        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
+        self.head_dim = self.kv_cache_spec.head_size
+        FlashInferBackend.validate_head_size(self.head_dim)
+        self.page_size = self.kv_cache_spec.block_size
+
+        self.enable_fusion = (
+            self.compilation_config.pass_config.enable_attn_fusion)
+        self.q_data_type = self.model_config.dtype
+        self.cache_dtype = self.cache_config.cache_dtype
+        if self.cache_dtype.startswith("fp8"):
+            self.kv_cache_dtype = (
+                FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                    self.cache_dtype))
+            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
+            if self.enable_fusion:
+                self.q_data_type = self.kv_cache_dtype
+        else:
+            self.kv_cache_dtype = self.kv_cache_spec.dtype
+        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
+                                 (self.num_qo_heads // self.num_kv_heads > 4))
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
@@ -282,14 +292,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             decode_wrapper = self._decode_wrapper
 
         if decode_wrapper is None:
-            num_qo_heads = (
-                self.vllm_config.model_config.get_num_attention_heads(
-                    self.vllm_config.parallel_config))
-            num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
-                self.vllm_config.parallel_config)
-            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
-                num_qo_heads // num_kv_heads > 4)
-
             if use_cudagraph:
                 paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
                 paged_kv_indices = self.paged_kv_indices
@@ -306,7 +308,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=use_tensor_cores)
+                use_tensor_cores=self.use_tensor_cores)
 
             # save the decode wrapper
             if use_cudagraph:
@@ -342,16 +344,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.shared_kv_last_page_len_cpu,
                     attn_metadata.paged_kv_last_page_len_cpu
                 ],
-                attn_metadata.num_qo_heads,
-                attn_metadata.num_kv_heads,
-                attn_metadata.head_dim,
-                attn_metadata.page_size,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
                 causal=True,
                 sm_scale=self.global_hyperparameters.sm_scale,
                 window_left=self.global_hyperparameters.window_left,
                 logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
-                q_data_type=attn_metadata.q_data_type,
-                kv_data_type=attn_metadata.kv_data_type,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.kv_cache_dtype,
             )
         else:
             # Regular attention (common case).
@@ -383,17 +385,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.paged_kv_indices,
                     attn_metadata.
                     paged_kv_last_page_len_cpu[prefill_start:],
-                    attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim,
-                    attn_metadata.page_size,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    self.page_size,
                     causal=True,
                     sm_scale=self.global_hyperparameters.sm_scale,
                     window_left=self.global_hyperparameters.window_left,
                     logits_soft_cap=self.global_hyperparameters.
                     logits_soft_cap,
-                    q_data_type=attn_metadata.q_data_type,
-                    kv_data_type=attn_metadata.kv_data_type,
+                    q_data_type=self.q_data_type,
+                    kv_data_type=self.kv_cache_dtype,
                 )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
@@ -435,18 +437,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                 attn_metadata.paged_kv_indices,
                 self.paged_kv_last_page_len_cpu[:num_input_tokens],
-                attn_metadata.num_qo_heads,
-                attn_metadata.num_kv_heads,
-                attn_metadata.head_dim,
-                attn_metadata.page_size,
+                attn_metadata.seq_lens_cpu[:num_input_tokens],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
                 sm_scale=self.global_hyperparameters.sm_scale,
                 window_left=self.global_hyperparameters.window_left,
                 logits_soft_cap=self.global_hyperparameters.
                 logits_soft_cap,
-                q_data_type=attn_metadata.q_data_type,
-                kv_data_type=attn_metadata.kv_data_type,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.kv_cache_dtype,
             )
 
     def build(self,
@@ -458,9 +461,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(common_attn_metadata)
 
-        page_size = self.kv_cache_spec.block_size
+        page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.seq_lens_cpu.max()
+        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -495,7 +498,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_page_indices_cpu = None
             shared_kv_last_page_len_cpu = None
 
-        max_num_blocks = block_table_bounds_cpu.max()
+        max_num_blocks = block_table_bounds_cpu.max().item()
         block_table_bounds = block_table_bounds_cpu.to(self.device,
                                                        non_blocking=True)
         mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
@@ -520,42 +523,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     paged_kv_last_page_len_cpu,
                     out=self.paged_kv_last_page_len_cpu[:num_reqs])
 
-        cache_dtype = self.cache_config.cache_dtype
-        if cache_dtype.startswith("fp8"):
-            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                cache_dtype)
-        else:
-            kv_cache_dtype = self.kv_cache_spec.dtype
-
-        config = self.vllm_config
-        num_qo_heads = config.model_config.get_num_attention_heads(
-            config.parallel_config)
-        num_kv_heads = self.kv_cache_spec.num_kv_heads
-        head_dim = self.kv_cache_spec.head_size
-
         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks
 
-        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
-        q_dtype = config.model_config.dtype
-        enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
-        if cache_dtype.startswith("fp8") and enable_fusion:
-            q_dtype = kv_cache_dtype
-
-        prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
-                                                  num_kv_heads,
+        prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
+                                                  self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
-                                                  cache_dtype,
-                                                  q_dtype,
+                                                  self.cache_dtype,
+                                                  self.q_data_type,
                                                   is_prefill=True,
                                                   has_sinks=has_sinks)
-        decode_use_trtllm = use_trtllm_attention(num_qo_heads,
-                                                 num_kv_heads,
+        decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
+                                                 self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
-                                                 cache_dtype,
-                                                 q_dtype,
+                                                 self.cache_dtype,
+                                                 self.q_data_type,
                                                  is_prefill=False,
                                                  has_sinks=has_sinks)
 
@@ -566,12 +550,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=self.
             paged_kv_last_page_len_cpu[:num_reqs],
-            num_qo_heads=num_qo_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_dim,
-            page_size=page_size,
-            kv_data_type=kv_cache_dtype,
-            q_data_type=q_dtype,
+            q_data_type=self.q_data_type,
+            seq_lens_cpu=seq_lens_cpu,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -910,6 +890,7 @@ def fast_plan_decode(
     indptr_cpu: torch.Tensor,
     indices: torch.Tensor,
     last_page_len_cpu: torch.Tensor,
+    seq_lens_cpu: torch.Tensor,
     num_qo_heads: int,
     num_kv_heads: int,
     head_dim: int,
@@ -987,9 +968,6 @@ def fast_plan_decode(
     kv_data_type = getattr(torch, kv_data_type) if isinstance(
         kv_data_type, str) else kv_data_type
 
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1006,12 +984,8 @@ def fast_plan_decode(
     self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                            non_blocking=True)
 
-    indptr_host = indptr_cpu
-    last_page_len_host = last_page_len_cpu
-
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
-                                        page_size)
+    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
         # Make sure we pass exactly 15 arguments for tensor core version
@@ -1020,8 +994,8 @@ def fast_plan_decode(
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
             qo_indptr_host,
-            indptr_host,
-            kv_lens_arr_host,
+            indptr_cpu,
+            seq_lens_cpu,
             batch_size,  # total_num_rows
             batch_size,
             num_qo_heads,
@@ -1041,7 +1015,7 @@ def fast_plan_decode(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            indptr_host,
+            indptr_cpu,
            batch_size,
             num_qo_heads,
             num_kv_heads,