diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 2740a6916fd97..623ae892ecdaf 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -16,6 +16,7 @@ from flashinfer import ( from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor +from typing_extensions import override from vllm import envs from vllm.attention.backends.abstract import ( @@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.utils import CpuGpuBuffer FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 @@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper: paged_kv_indptr_cpu: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len_cpu: torch.Tensor, - prefill_start: int, page_size: int, num_qo_heads: int, dcp_world_size: int, @@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper: qo_indptr_cpu, paged_kv_indptr_cpu, paged_kv_indices, - paged_kv_last_page_len_cpu[prefill_start:], + paged_kv_last_page_len_cpu, num_qo_heads * dcp_world_size, num_kv_heads, head_dim, @@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend): @dataclass -class FlashInferMetadata: - num_actual_tokens: int # Number of tokens excluding padding. +class FIPrefill: + """Metadata for the native FlashInfer prefill pathway (non-TRTLLM).""" - # The data type of the query - q_data_type: torch.dtype + wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper + + +@dataclass +class FIDecode: + """Metadata for the native FlashInfer decode pathway (non-TRTLLM).""" + + wrapper: BatchDecodeWithPagedKVCacheWrapper + + +@dataclass +class TRTLLMPrefill: + """Metadata for the TRTLLM prefill pathway.""" + + block_tables: torch.Tensor + """ + The slice of the block table tensor corresponding *only* to prefill requests. + Shape: [num_prefills, max_num_blocks_per_seq] + """ + + seq_lens: torch.Tensor + """ + The slice of the sequence lengths tensor corresponding *only* to prefill requests. + Shape: [num_prefills] + """ + + cum_seq_lens_q: torch.Tensor + cum_seq_lens_kv: torch.Tensor + + max_q_len: int + """ + The maximum query length *among prefill requests*. + """ + + max_seq_len: int + """The maximum sequence length for KV Cache.""" + + +@dataclass +class TRTLLMDecode: + """Metadata for the TRTLLM decode pathway.""" + + block_tables: torch.Tensor + """ + The slice of the block table tensor corresponding *only* to decode requests. + Shape: [num_decodes, max_num_blocks_per_seq] + """ + + seq_lens: torch.Tensor + """ + The slice of the sequence lengths tensor corresponding *only* to decode requests. + Shape: [num_decodes] + """ + + max_seq_len: int + """The maximum sequence length for KV Cache.""" + + +@dataclass +class FlashInferMetadata: + num_actual_tokens: int + """Total number of tokens in the batch (excluding padding).""" slot_mapping: torch.Tensor + """Tensor for writing K/V to the cache. 
Shape: [num_actual_tokens]""" - # For flashinfer trtllm batch decode - max_q_len: int - max_q_len_prefill: int - max_seq_len: int - seq_lens: torch.Tensor - block_table_tensor: torch.Tensor - prefill_use_trtllm: bool - decode_use_trtllm: bool + q_data_type: torch.dtype - # For handling prefill decode split num_decodes: int num_decode_tokens: int num_prefills: int num_prefill_tokens: int - # For cascade attention (CPU for planning). + prefill: FIPrefill | TRTLLMPrefill | None + """ + Holds the metadata for the prefill portion of the batch. + Will be `None` if `num_prefill_tokens == 0`. + """ + + decode: FIDecode | TRTLLMDecode | None + """ + Holds the metadata for the decode portion of the batch. + Will be `None` if `num_decode_tokens == 0`. + """ + + # --- Special Case: Cascade Attention --- + use_cascade: bool + """ + If True, the entire batch is a cascade attention call, and the + `prefill` and `decode` fields will both be None. + """ - prefill_wrapper: ( - BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None - ) = None - decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None - cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None - - qo_indptr_gpu: torch.Tensor | None = None - paged_kv_indptr_gpu: torch.Tensor | None = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper | None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.dcp_world_size = 1 self.dcp_rank = 0 self.dcp_kv_cache_interleave_size = 1 + self.use_dcp = self.dcp_world_size > 1 self.num_qo_heads = self.model_config.get_num_attention_heads( self.vllm_config.parallel_config @@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): "sinks, please use trtllm on blackwell or flash attention on " "earlier GPUs." 
) - # Preparing persistent buffers (device-side) - self.paged_kv_indptr = torch.zeros( - max_num_reqs + 1, dtype=torch.int32, device=self.device - ) - self.paged_kv_indices = torch.zeros( - max_num_pages, # max num pages possible - dtype=torch.int32, - device=self.device, - ) - self.paged_kv_last_page_len = torch.zeros( - max_num_reqs, dtype=torch.int32, device=self.device - ) - # host-side buffer - pin_memory = is_pin_memory_available() - self.paged_kv_indptr_cpu = torch.zeros( - max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory - ) - self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() - self.paged_kv_indptr_buffer = torch.zeros_like( - self.paged_kv_indptr_cpu, pin_memory=pin_memory - ) - self.paged_kv_indices_cpu = torch.zeros( - max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory - ) - self.paged_kv_last_page_len_cpu = torch.zeros( - max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory - ) - self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() + # Preparing persistent buffers + self.pin_memory = is_pin_memory_available() + self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1) + self.paged_kv_indptr_cpu_buffer = torch.zeros_like( + self.paged_kv_indptr.cpu, pin_memory=self.pin_memory + ) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode + self.paged_kv_indices = self._make_buffer(max_num_pages) + self.paged_kv_last_page_len = self._make_buffer(max_num_reqs) if self.head_dim == 256 and current_platform.is_device_capability_family(100): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that @@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): "passing --block-size 32 or --block-size 64." ) + def _make_buffer( + self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32 + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=True, + ) + + @override # type: ignore[misc] @classmethod def get_cudagraph_support( cls: type["FlashInferMetadataBuilder"], @@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self, ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: if self._prefill_wrapper is None: - if self.dcp_world_size > 1: + if self.use_dcp: self._prefill_wrapper = BatchDCPPrefillWrapper( workspace_buffer=self._get_workspace_buffer(), ) @@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): if decode_wrapper is None: if use_cudagraph: - paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] - paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] + paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1] + paged_kv_indices = self.paged_kv_indices.gpu + paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -661,6 +718,60 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) return self._cascade_wrapper + def _compute_flashinfer_kv_metadata( + self, + num_blocks_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table_tensor: torch.Tensor, + num_reqs: int, + page_size: int, + ) -> torch.Tensor: + """ + Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer + attention. 
+ + Results are stored in self.paged_kv_indptr, + self.paged_kv_indices, self.paged_kv_last_page_len buffers. + + Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages]. + """ + # write self.paged_kv_indptr_cpu inplace (0-index is always 0) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr.np[1 : num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. + self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True + ) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr.np[num_reqs] + paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs,)]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) + + # write self.paged_kv_last_page_len_cpu inplace + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len.np[:num_reqs] = np.where( + (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), + page_size, + paged_kv_last_page_len_np, + ) + return paged_kv_indices + def build( self, common_prefix_len: int, @@ -678,98 +789,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) page_size = self.page_size - max_q_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor + qo_indptr = common_attn_metadata.query_start_loc qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - if self.dcp_world_size > 1: - if num_prefills > 0: - qo_indptr_prefill_cpu = ( - qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] - ) - query_lens_prefill_cpu = ( - qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] - ) - seq_lens_cpu[num_decodes:] = ( - seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu - ) - - seq_lens_cpu = get_dcp_local_seq_lens( - seq_lens_cpu, - self.dcp_world_size, - self.dcp_rank, - self.dcp_kv_cache_interleave_size, - ) - - seq_lens_np = seq_lens_cpu.numpy() - num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size - + # Step 1: Decide which dispatch modes to use: + # - Cascade attention (distinct mode) + # - Prefill (FI native or TRTLLM) + # - Decode (FI native or TRTLLM) use_cascade = common_prefix_len > 0 - if use_cascade: - # Grab the blocks of the shared prefix from the first request. - assert common_prefix_len % page_size == 0 - num_common_kv_blocks = common_prefix_len // page_size - - # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor( - [0, num_actual_tokens], dtype=torch.int32, device="cpu" - ) - shared_kv_page_indptr_cpu = torch.tensor( - [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" - ) - shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor( - [page_size], dtype=torch.int32, device="cpu" - ) - - # Remove the blocks of the shared prefix from all requests. 
- block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - num_blocks_np -= num_common_kv_blocks - else: - shared_qo_indptr_cpu = None - shared_kv_page_indptr_cpu = None - shared_kv_page_indices_cpu = None - shared_kv_last_page_len_cpu = None - - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - np.cumsum( - num_blocks_np, - dtype=np.int32, - out=self.paged_kv_indptr_np[1 : num_reqs + 1], - ) - # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified - # after this line (e.g., for cuda graphs), we need to copy the data to - # self.paged_kv_indptr_buffer to avoid race condition. - self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ - : num_reqs + 1 - ] - paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] - paged_kv_indptr.copy_( - self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True - ) - - # write self.paged_kv_indices inplace - num_actual_pages = self.paged_kv_indptr_np[num_reqs] - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - _copy_page_indices_kernel[(num_reqs,)]( - paged_kv_indices, - block_table_tensor, - block_table_tensor.stride(0), - paged_kv_indptr, - BLOCK_SIZE=1024, - ) - - # write self.paged_kv_last_page_len_cpu inplace - paged_kv_last_page_len_np = seq_lens_np % page_size - self.paged_kv_last_page_len_np[:num_reqs] = np.where( - (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), - page_size, - paged_kv_last_page_len_np, - ) - uses_spec_reorder = self.reorder_batch_threshold > 1 prefill_use_trtllm = use_trtllm_attention( self.num_qo_heads, @@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.use_trtllm_decode_attention and self.dcp_world_size <= 1 ) - if not (prefill_use_trtllm and decode_use_trtllm): + all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and ( + num_decodes == 0 or decode_use_trtllm + ) + is_only_trtllm_decode = num_prefills == 0 and ( + num_decodes > 0 and decode_use_trtllm + ) + + if not all_uses_trtllm: if self.has_sinks: raise NotImplementedError( "FlashInfer backend currently does not support attention " @@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # fall back to model dtype. self.q_data_type = self.model_config.dtype + # Step 2: Initialize the output metadata + # Leave prefill/decode/cascade_wrapper empty, to be populated + # case by case depending on the batch contents and backend selection. attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - q_data_type=self.q_data_type, slot_mapping=common_attn_metadata.slot_mapping, - max_q_len=max_q_len, - max_q_len_prefill=max_q_len, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table_tensor=block_table_tensor, - prefill_use_trtllm=prefill_use_trtllm, - decode_use_trtllm=decode_use_trtllm, + q_data_type=self.q_data_type, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, + prefill=None, + decode=None, + cascade_wrapper=None, ) - paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] - paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + # Guard access to seq_lens_cpu, which may not always be needed + # and can be expensive to retrieve in async mode. 
+ needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode + seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None + seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None + num_blocks_np = ( + (seq_lens_np + (page_size - 1)) // page_size + if seq_lens_np is not None + else None + ) + + # Adjust seq_lens_cpu for DCP + if self.use_dcp: + assert seq_lens_cpu is not None + if num_prefills > 0: + qo_indptr_prefill_cpu = ( + qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] + ) + query_lens_prefill_cpu = ( + qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] + ) + seq_lens_cpu[num_decodes:] = ( + seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu + ) + + seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + + # Adjust num_block_np for cascade attention + if use_cascade: + assert num_blocks_np is not None + assert common_prefix_len % page_size == 0 + num_common_kv_blocks = common_prefix_len // page_size + num_blocks_np -= num_common_kv_blocks + + # Compute paged_kv_indices if necessary + needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode + if needs_paged_kv_indices: + assert num_blocks_np is not None + assert seq_lens_np is not None + paged_kv_indices = self._compute_flashinfer_kv_metadata( + num_blocks_np, + seq_lens_np, + block_table_tensor, + num_reqs, + page_size, + ) + else: + paged_kv_indices = None + + # Early-out for cascade attention + if use_cascade: + # Grab the blocks of the shared prefix from the first request. + num_common_kv_blocks = common_prefix_len // page_size + + # Create CPU versions directly for cascade (no GPU versions needed) + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) + + # Remove the blocks of the shared prefix from all requests. + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] + num_blocks_np -= num_common_kv_blocks + + assert paged_kv_indices is not None + paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs] - if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( [shared_qo_indptr_cpu, qo_indptr_cpu], @@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, ) - else: - # Regular attention (common case). - # Decodes are at the front and prefills are at the back. 
- num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decodes - if num_prefills > 0: - # Decodes are first so prefills start after the last decode - prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 - assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 - assert ( - paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills - ) - # Since prefill_wrapper.run() will be called with - # query[num_decode_tokens:] we need to adjust the qo_indptr - # to be relative to the start of the prefill queries. - qo_indptr_cpu = ( - qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] - ) - paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + return attn_metadata - # Recompute max_q_len for the slice of requests we are using - # for prefills. This can be different from max_q_len when - # we have a non-uniform batch with some short decodes offloaded - # to the prefill pathway - query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) + # Step 3: Handle prefill and decode pathways case by case + ## PREFILL PATHWAY + if num_prefills > 0: + # Slices for shared prefill metadata + prefill_start = num_decodes + qo_indptr_prefill_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) + assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1 - if not attn_metadata.prefill_use_trtllm: - if self.dcp_world_size > 1: - assert isinstance( - attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper - ) - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu=qo_indptr_cpu, - paged_kv_indptr_cpu=paged_kv_indptr_cpu, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, - prefill_start=prefill_start, - page_size=self.page_size, - num_qo_heads=self.num_qo_heads, - dcp_world_size=self.dcp_world_size, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_cache_dtype=self.kv_cache_dtype, - prefill_fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) - else: - assert isinstance( - attn_metadata.prefill_wrapper, - BatchPrefillWithPagedKVCacheWrapper, - ) - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - paged_kv_indices, - paged_kv_last_page_len_cpu[prefill_start:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) + if prefill_use_trtllm: + # Create GPU versions + qo_indptr_prefill_gpu = ( + qo_indptr[prefill_start:] - qo_indptr[prefill_start] + ) + paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[ + prefill_start : num_reqs + 1 + ] + # Compute max_q_len for prefill requests + query_lens_prefill_cpu = ( + qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] + ) + max_q_len_prefill = int(query_lens_prefill_cpu.max().item()) + attn_metadata.prefill = TRTLLMPrefill( + block_tables=block_table_tensor[prefill_start:], + seq_lens=seq_lens[prefill_start:], + cum_seq_lens_q=qo_indptr_prefill_gpu, + 
cum_seq_lens_kv=paged_kv_indptr_prefill_gpu, + max_q_len=max_q_len_prefill, + max_seq_len=max_seq_len, + ) + else: + prefill_wrapper = self._get_prefill_wrapper() + # Slicing CPU buffers that are only needed for FI native prefills + paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[ + prefill_start:num_reqs + ] + assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills + paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[ + prefill_start : num_reqs + 1 + ] + assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1 + if self.use_dcp: + assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) + prefill_wrapper.plan( + qo_indptr_cpu=qo_indptr_prefill_cpu, + paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu, + page_size=self.page_size, + num_qo_heads=self.num_qo_heads, + dcp_world_size=self.dcp_world_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_cache_dtype=self.kv_cache_dtype, + prefill_fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) else: - attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( - self.device, non_blocking=True + assert isinstance( + prefill_wrapper, + BatchPrefillWithPagedKVCacheWrapper, ) - attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device, non_blocking=True + prefill_wrapper.plan( + qo_indptr_prefill_cpu, + paged_kv_indptr_prefill_cpu, + paged_kv_indices, + paged_kv_last_page_len_prefill_cpu, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, ) + attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper) - if num_decodes > 0: + ## DECODE PATHWAY + if num_decodes > 0: + if decode_use_trtllm: + assert num_decode_tokens % num_decodes == 0, ( + "TRTLLM decode requires uniform query lengths per request." + ) + attn_metadata.decode = TRTLLMDecode( + block_tables=block_table_tensor[:num_decodes], + seq_lens=seq_lens[:num_decodes], + max_seq_len=max_seq_len, + ) + else: pure_decode = num_prefills == 0 use_cudagraph = ( self.enable_cuda_graph @@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) num_input_tokens = num_decode_tokens - attn_metadata.decode_wrapper = self._get_decode_wrapper( + decode_wrapper = self._get_decode_wrapper( num_input_tokens, use_cudagraph ) - if not attn_metadata.decode_use_trtllm: - # Use the persistent buffer with padding length, - # instead of the same address but chunked version - # in atten_metadata when using cudagraph. - fast_plan_decode( - attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[: num_input_tokens + 1], - paged_kv_indices, - self.paged_kv_last_page_len_cpu[:num_input_tokens], - seq_lens_cpu[:num_input_tokens], - self.num_qo_heads * self.dcp_world_size, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. 
- pos_encoding_mode="NONE", - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.decode_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + decode_wrapper, + self.paged_kv_indptr.cpu[: num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len.cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads * self.dcp_world_size, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.decode_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + attn_metadata.decode = FIDecode(wrapper=decode_wrapper) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl): if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float + prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill) + decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode) + # The attn+quant fusion happens when output_scale is provided. if output_scale is None: assert output_block_scale is None, ( @@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl): assert attn_metadata.q_data_type == FP8_DTYPE, ( "Query must be FP8 when attn+quant fusion happened." ) - assert ( - attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and ( + attn_metadata.num_decodes == 0 or decode_use_trtllm ), "Must use TRT-LLM attn" if output.dtype == FP8_DTYPE: @@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl): # When using spec decoding, num_decodes can be < num_decode_tokens # because some decode requests may have more than one query token. - num_decodes = attn_metadata.num_decodes num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens stride_order = FlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) + + use_dcp = self.dcp_world_size > 1 + # Regular attention (common case). # Decodes are at the front and prefills are at the back. 
if num_prefill_tokens > 0: - prefill_wrapper = attn_metadata.prefill_wrapper prefill_query = query[num_decode_tokens:] assert prefill_query.shape[0] == num_prefill_tokens - assert prefill_wrapper is not None - if not attn_metadata.prefill_use_trtllm: - if self.dcp_world_size > 1: + if not prefill_use_trtllm: + assert isinstance(attn_metadata.prefill, FIPrefill) + prefill_wrapper = attn_metadata.prefill.wrapper + assert prefill_wrapper is not None + if use_dcp: assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) assert prefill_wrapper._context._window_left == self.window_left assert prefill_wrapper._context._logits_soft_cap == ( @@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl): out=output[num_decode_tokens:], ) else: + assert isinstance(attn_metadata.prefill, TRTLLMPrefill) # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] - seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] + block_tables_prefill = attn_metadata.prefill.block_tables + seq_lens_prefill = attn_metadata.prefill.seq_lens # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" @@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl): workspace_buffer=workspace_buffer, block_tables=mock_block_table, seq_lens=seq_lens_prefill, - max_q_len=attn_metadata.max_q_len_prefill, - max_kv_len=attn_metadata.max_seq_len, + max_q_len=attn_metadata.prefill.max_q_len, + max_kv_len=attn_metadata.prefill.max_seq_len, bmm1_scale=self.bmm1_scale, bmm2_scale=self.bmm2_scale, batch_size=attn_metadata.num_prefills, - cum_seq_lens_q=attn_metadata.qo_indptr_gpu, - cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, + cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q, + cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv, window_left=self.window_left, sinks=self.sinks, o_sf_scale=self.o_sf_scale, @@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl): ) if num_decode_tokens > 0: - decode_wrapper = attn_metadata.decode_wrapper decode_query = query[:num_decode_tokens] assert decode_query.shape[0] == num_decode_tokens - assert decode_wrapper is not None - if not attn_metadata.decode_use_trtllm: + if not decode_use_trtllm: + assert isinstance(attn_metadata.decode, FIDecode) + decode_wrapper = attn_metadata.decode.wrapper + assert decode_wrapper is not None assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale - if self.dcp_world_size > 1: + if use_dcp: decode_query = get_dcp_group().all_gather( decode_query.contiguous(), dim=-2 ) @@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl): ) else: # decode_query may be non-contiguous + assert isinstance(attn_metadata.decode, TRTLLMDecode) decode_query = decode_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_decode = attn_metadata.block_table_tensor[ - :num_decode_tokens - ] - seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] + block_tables_decode = attn_metadata.decode.block_tables + seq_lens_decode = attn_metadata.decode.seq_lens # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" @@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl): workspace_buffer=workspace_buffer, block_tables=block_tables_decode, seq_lens=seq_lens_decode, - 
max_seq_len=attn_metadata.max_seq_len, + max_seq_len=attn_metadata.decode.max_seq_len, bmm1_scale=self.bmm1_scale, bmm2_scale=self.bmm2_scale, window_left=self.window_left,
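
For reference, a minimal sketch (not part of the patch) of how a consumer is expected to dispatch on the new split metadata. The dataclass names come from this diff; the import assumes the patch has landed, and the ellipses stand in for the actual kernel calls made in `FlashInferImpl.forward()`:

# Sketch only: mirrors the isinstance-based dispatch introduced in this diff.
from vllm.v1.attention.backends.flashinfer import (  # names added by this patch
    FIDecode,
    FIPrefill,
    FlashInferMetadata,
    TRTLLMDecode,
    TRTLLMPrefill,
)

def _dispatch(meta: FlashInferMetadata) -> None:
    if meta.use_cascade:
        # Cascade attention covers the whole batch; prefill/decode are both None.
        assert meta.cascade_wrapper is not None
        return
    if meta.num_prefill_tokens > 0:
        if isinstance(meta.prefill, TRTLLMPrefill):
            # trtllm_batch_context_with_kv_cache(...) consumes the prefill-only
            # slices (block_tables, seq_lens, cum_seq_lens_q/kv, max_q_len, max_seq_len).
            ...
        else:
            assert isinstance(meta.prefill, FIPrefill)
            # meta.prefill.wrapper was plan()-ed in build(); forward() calls its run().
            ...
    if meta.num_decode_tokens > 0:
        if isinstance(meta.decode, TRTLLMDecode):
            # trtllm_batch_decode_with_kv_cache(...) consumes the decode-only
            # block_tables, seq_lens, and max_seq_len.
            ...
        else:
            assert isinstance(meta.decode, FIDecode)
            # meta.decode.wrapper was plan()-ed via fast_plan_decode() in build().
            ...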