[Cleanup] Refactor FlashInferMetadataBuilder (#29128)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Benjamin Chislett 2025-12-18 17:45:30 -05:00 committed by GitHub
parent 6ca74bc11a
commit d6b3d39b6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,6 +16,7 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
paged_kv_indptr_cpu: torch.Tensor, paged_kv_indptr_cpu: torch.Tensor,
paged_kv_indices: torch.Tensor, paged_kv_indices: torch.Tensor,
paged_kv_last_page_len_cpu: torch.Tensor, paged_kv_last_page_len_cpu: torch.Tensor,
prefill_start: int,
page_size: int, page_size: int,
num_qo_heads: int, num_qo_heads: int,
dcp_world_size: int, dcp_world_size: int,
@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
qo_indptr_cpu, qo_indptr_cpu,
paged_kv_indptr_cpu, paged_kv_indptr_cpu,
paged_kv_indices, paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:], paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size, num_qo_heads * dcp_world_size,
num_kv_heads, num_kv_heads,
head_dim, head_dim,
@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
@dataclass @dataclass
class FlashInferMetadata: class FIPrefill:
num_actual_tokens: int # Number of tokens excluding padding. """Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
# The data type of the query wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
q_data_type: torch.dtype
@dataclass
class FIDecode:
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
wrapper: BatchDecodeWithPagedKVCacheWrapper
@dataclass
class TRTLLMPrefill:
"""Metadata for the TRTLLM prefill pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to prefill requests.
Shape: [num_prefills, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
Shape: [num_prefills]
"""
cum_seq_lens_q: torch.Tensor
cum_seq_lens_kv: torch.Tensor
max_q_len: int
"""
The maximum query length *among prefill requests*.
"""
max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class TRTLLMDecode:
"""Metadata for the TRTLLM decode pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to decode requests.
Shape: [num_decodes, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to decode requests.
Shape: [num_decodes]
"""
max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class FlashInferMetadata:
num_actual_tokens: int
"""Total number of tokens in the batch (excluding padding)."""
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
# For flashinfer trtllm batch decode q_data_type: torch.dtype
max_q_len: int
max_q_len_prefill: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
prefill_use_trtllm: bool
decode_use_trtllm: bool
# For handling prefill decode split
num_decodes: int num_decodes: int
num_decode_tokens: int num_decode_tokens: int
num_prefills: int num_prefills: int
num_prefill_tokens: int num_prefill_tokens: int
# For cascade attention (CPU for planning). prefill: FIPrefill | TRTLLMPrefill | None
"""
Holds the metadata for the prefill portion of the batch.
Will be `None` if `num_prefill_tokens == 0`.
"""
decode: FIDecode | TRTLLMDecode | None
"""
Holds the metadata for the decode portion of the batch.
Will be `None` if `num_decode_tokens == 0`.
"""
# --- Special Case: Cascade Attention ---
use_cascade: bool use_cascade: bool
"""
If True, the entire batch is a cascade attention call, and the
`prefill` and `decode` fields will both be None.
"""
prefill_wrapper: ( cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
) = None
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: torch.Tensor | None = None
paged_kv_indptr_gpu: torch.Tensor | None = None
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_world_size = 1 self.dcp_world_size = 1
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1 self.dcp_kv_cache_interleave_size = 1
self.use_dcp = self.dcp_world_size > 1
self.num_qo_heads = self.model_config.get_num_attention_heads( self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config self.vllm_config.parallel_config
@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"sinks, please use trtllm on blackwell or flash attention on " "sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs." "earlier GPUs."
) )
# Preparing persistent buffers (device-side) # Preparing persistent buffers
self.paged_kv_indptr = torch.zeros( self.pin_memory = is_pin_memory_available()
max_num_reqs + 1, dtype=torch.int32, device=self.device self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
) self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
self.paged_kv_indices = torch.zeros( self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
max_num_pages, # max num pages possible ) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
dtype=torch.int32, self.paged_kv_indices = self._make_buffer(max_num_pages)
device=self.device, self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
)
self.paged_kv_last_page_len = torch.zeros(
max_num_reqs, dtype=torch.int32, device=self.device
)
# host-side buffer
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indptr_buffer = torch.zeros_like(
self.paged_kv_indptr_cpu, pin_memory=pin_memory
)
self.paged_kv_indices_cpu = torch.zeros(
max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_cpu = torch.zeros(
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
if self.head_dim == 256 and current_platform.is_device_capability_family(100): if self.head_dim == 256 and current_platform.is_device_capability_family(100):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64." "passing --block-size 32 or --block-size 64."
) )
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory,
with_numpy=True,
)
@override # type: ignore[misc]
@classmethod @classmethod
def get_cudagraph_support( def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"], cls: type["FlashInferMetadataBuilder"],
@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self, self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
if self._prefill_wrapper is None: if self._prefill_wrapper is None:
if self.dcp_world_size > 1: if self.use_dcp:
self._prefill_wrapper = BatchDCPPrefillWrapper( self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(), workspace_buffer=self._get_workspace_buffer(),
) )
@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if decode_wrapper is None: if decode_wrapper is None:
if use_cudagraph: if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
paged_kv_indices = self.paged_kv_indices paged_kv_indices = self.paged_kv_indices.gpu
paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
else: else:
paged_kv_indptr = None paged_kv_indptr = None
paged_kv_indices = None paged_kv_indices = None
@ -661,6 +718,60 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
return self._cascade_wrapper return self._cascade_wrapper
def _compute_flashinfer_kv_metadata(
self,
num_blocks_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table_tensor: torch.Tensor,
num_reqs: int,
page_size: int,
) -> torch.Tensor:
"""
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
attention.
Results are stored in self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
"""
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr.np[1 : num_reqs + 1],
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
: num_reqs + 1
]
paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
paged_kv_indptr.copy_(
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
)
# write self.paged_kv_indices inplace
num_actual_pages = self.paged_kv_indptr.np[num_reqs]
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size,
paged_kv_last_page_len_np,
)
return paged_kv_indices
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
@ -678,98 +789,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
page_size = self.page_size page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
qo_indptr = common_attn_metadata.query_start_loc
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
if self.dcp_world_size > 1: # Step 1: Decide which dispatch modes to use:
if num_prefills > 0: # - Cascade attention (distinct mode)
qo_indptr_prefill_cpu = ( # - Prefill (FI native or TRTLLM)
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] # - Decode (FI native or TRTLLM)
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
seq_lens_np = seq_lens_cpu.numpy()
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0 use_cascade = common_prefix_len > 0
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
num_blocks_np -= num_common_kv_blocks
else:
shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr_np[1 : num_reqs + 1],
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
: num_reqs + 1
]
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
paged_kv_indptr.copy_(
self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
)
# write self.paged_kv_indices inplace
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size,
paged_kv_last_page_len_np,
)
uses_spec_reorder = self.reorder_batch_threshold > 1 uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention( prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads, self.num_qo_heads,
@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.use_trtllm_decode_attention and self.dcp_world_size <= 1 self.use_trtllm_decode_attention and self.dcp_world_size <= 1
) )
if not (prefill_use_trtllm and decode_use_trtllm): all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
num_decodes == 0 or decode_use_trtllm
)
is_only_trtllm_decode = num_prefills == 0 and (
num_decodes > 0 and decode_use_trtllm
)
if not all_uses_trtllm:
if self.has_sinks: if self.has_sinks:
raise NotImplementedError( raise NotImplementedError(
"FlashInfer backend currently does not support attention " "FlashInfer backend currently does not support attention "
@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# fall back to model dtype. # fall back to model dtype.
self.q_data_type = self.model_config.dtype self.q_data_type = self.model_config.dtype
# Step 2: Initialize the output metadata
# Leave prefill/decode/cascade_wrapper empty, to be populated
# case by case depending on the batch contents and backend selection.
attn_metadata = FlashInferMetadata( attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len, q_data_type=self.q_data_type,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
prefill_use_trtllm=prefill_use_trtllm,
decode_use_trtllm=decode_use_trtllm,
num_decodes=num_decodes, num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade, use_cascade=use_cascade,
prefill=None,
decode=None,
cascade_wrapper=None,
) )
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] # Guard access to seq_lens_cpu, which may not always be needed
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] # and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = (
(seq_lens_np + (page_size - 1)) // page_size
if seq_lens_np is not None
else None
)
# Adjust seq_lens_cpu for DCP
if self.use_dcp:
assert seq_lens_cpu is not None
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
# Adjust num_block_np for cascade attention
if use_cascade:
assert num_blocks_np is not None
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
num_blocks_np -= num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
if needs_paged_kv_indices:
assert num_blocks_np is not None
assert seq_lens_np is not None
paged_kv_indices = self._compute_flashinfer_kv_metadata(
num_blocks_np,
seq_lens_np,
block_table_tensor,
num_reqs,
page_size,
)
else:
paged_kv_indices = None
# Early-out for cascade attention
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
num_blocks_np -= num_common_kv_blocks
assert paged_kv_indices is not None
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan( attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu], [shared_qo_indptr_cpu, qo_indptr_cpu],
@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type, q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype, kv_data_type=self.kv_cache_dtype,
) )
else: return attn_metadata
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
assert (
paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
)
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu = (
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
)
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
# Recompute max_q_len for the slice of requests we are using # Step 3: Handle prefill and decode pathways case by case
# for prefills. This can be different from max_q_len when ## PREFILL PATHWAY
# we have a non-uniform batch with some short decodes offloaded if num_prefills > 0:
# to the prefill pathway # Slices for shared prefill metadata
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] prefill_start = num_decodes
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) qo_indptr_prefill_cpu = (
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
)
assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1
if not attn_metadata.prefill_use_trtllm: if prefill_use_trtllm:
if self.dcp_world_size > 1: # Create GPU versions
assert isinstance( qo_indptr_prefill_gpu = (
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper qo_indptr[prefill_start:] - qo_indptr[prefill_start]
) )
attn_metadata.prefill_wrapper.plan( paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
qo_indptr_cpu=qo_indptr_cpu, prefill_start : num_reqs + 1
paged_kv_indptr_cpu=paged_kv_indptr_cpu, ]
paged_kv_indices=paged_kv_indices, # Compute max_q_len for prefill requests
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, query_lens_prefill_cpu = (
prefill_start=prefill_start, qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
page_size=self.page_size, )
num_qo_heads=self.num_qo_heads, max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
dcp_world_size=self.dcp_world_size, attn_metadata.prefill = TRTLLMPrefill(
num_kv_heads=self.num_kv_heads, block_tables=block_table_tensor[prefill_start:],
head_dim=self.head_dim, seq_lens=seq_lens[prefill_start:],
sm_scale=self.sm_scale, cum_seq_lens_q=qo_indptr_prefill_gpu,
window_left=self.window_left, cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
logits_soft_cap=self.logits_soft_cap, max_q_len=max_q_len_prefill,
q_data_type=self.q_data_type, max_seq_len=max_seq_len,
kv_cache_dtype=self.kv_cache_dtype, )
prefill_fixed_split_size=self.prefill_fixed_split_size, else:
disable_split_kv=self.disable_split_kv, prefill_wrapper = self._get_prefill_wrapper()
) # Slicing CPU buffers that are only needed for FI native prefills
else: paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
assert isinstance( prefill_start:num_reqs
attn_metadata.prefill_wrapper, ]
BatchPrefillWithPagedKVCacheWrapper, assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
) paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
attn_metadata.prefill_wrapper.plan( prefill_start : num_reqs + 1
qo_indptr_cpu, ]
paged_kv_indptr_cpu, assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
paged_kv_indices, if self.use_dcp:
paged_kv_last_page_len_cpu[prefill_start:], assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
self.num_qo_heads, prefill_wrapper.plan(
self.num_kv_heads, qo_indptr_cpu=qo_indptr_prefill_cpu,
self.head_dim, paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
self.page_size, paged_kv_indices=paged_kv_indices,
causal=True, paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
sm_scale=self.sm_scale, page_size=self.page_size,
window_left=self.window_left, num_qo_heads=self.num_qo_heads,
logits_soft_cap=self.logits_soft_cap, dcp_world_size=self.dcp_world_size,
q_data_type=self.q_data_type, num_kv_heads=self.num_kv_heads,
kv_data_type=self.kv_cache_dtype, head_dim=self.head_dim,
fixed_split_size=self.prefill_fixed_split_size, sm_scale=self.sm_scale,
disable_split_kv=self.disable_split_kv, window_left=self.window_left,
) logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_cache_dtype=self.kv_cache_dtype,
prefill_fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else: else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( assert isinstance(
self.device, non_blocking=True prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
) )
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( prefill_wrapper.plan(
self.device, non_blocking=True qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu,
paged_kv_indices,
paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
) )
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
if num_decodes > 0: ## DECODE PATHWAY
if num_decodes > 0:
if decode_use_trtllm:
assert num_decode_tokens % num_decodes == 0, (
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata.decode = TRTLLMDecode(
block_tables=block_table_tensor[:num_decodes],
seq_lens=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
)
else:
pure_decode = num_prefills == 0 pure_decode = num_prefills == 0
use_cudagraph = ( use_cudagraph = (
self.enable_cuda_graph self.enable_cuda_graph
@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
num_input_tokens = num_decode_tokens num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper( decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph num_input_tokens, use_cudagraph
) )
if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length,
# Use the persistent buffer with padding length, # instead of the same address but chunked version
# instead of the same address but chunked version # in atten_metadata when using cudagraph.
# in atten_metadata when using cudagraph. fast_plan_decode(
fast_plan_decode( decode_wrapper,
attn_metadata.decode_wrapper, self.paged_kv_indptr.cpu[: num_input_tokens + 1],
self.paged_kv_indptr_cpu[: num_input_tokens + 1], paged_kv_indices,
paged_kv_indices, self.paged_kv_last_page_len.cpu[:num_input_tokens],
self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens], self.num_qo_heads * self.dcp_world_size,
self.num_qo_heads * self.dcp_world_size, self.num_kv_heads,
self.num_kv_heads, self.head_dim,
self.head_dim, self.page_size,
self.page_size, # Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE",
pos_encoding_mode="NONE", sm_scale=self.sm_scale,
sm_scale=self.sm_scale, window_left=self.window_left,
window_left=self.window_left, logits_soft_cap=self.logits_soft_cap,
logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type,
q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype,
kv_data_type=self.kv_cache_dtype, fixed_split_size=self.decode_fixed_split_size,
fixed_split_size=self.decode_fixed_split_size, disable_split_kv=self.disable_split_kv,
disable_split_kv=self.disable_split_kv, )
) attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
return attn_metadata return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float self.bmm2_scale = layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
# The attn+quant fusion happens when output_scale is provided. # The attn+quant fusion happens when output_scale is provided.
if output_scale is None: if output_scale is None:
assert output_block_scale is None, ( assert output_block_scale is None, (
@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
assert attn_metadata.q_data_type == FP8_DTYPE, ( assert attn_metadata.q_data_type == FP8_DTYPE, (
"Query must be FP8 when attn+quant fusion happened." "Query must be FP8 when attn+quant fusion happened."
) )
assert ( assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm attn_metadata.num_decodes == 0 or decode_use_trtllm
), "Must use TRT-LLM attn" ), "Must use TRT-LLM attn"
if output.dtype == FP8_DTYPE: if output.dtype == FP8_DTYPE:
@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
# When using spec decoding, num_decodes can be < num_decode_tokens # When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token. # because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
stride_order = FlashInferBackend.get_kv_cache_stride_order() stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order) kv_cache_permute = kv_cache.permute(*stride_order)
use_dcp = self.dcp_world_size > 1
# Regular attention (common case). # Regular attention (common case).
# Decodes are at the front and prefills are at the back. # Decodes are at the front and prefills are at the back.
if num_prefill_tokens > 0: if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:] prefill_query = query[num_decode_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens assert prefill_query.shape[0] == num_prefill_tokens
assert prefill_wrapper is not None
if not attn_metadata.prefill_use_trtllm: if not prefill_use_trtllm:
if self.dcp_world_size > 1: assert isinstance(attn_metadata.prefill, FIPrefill)
prefill_wrapper = attn_metadata.prefill.wrapper
assert prefill_wrapper is not None
if use_dcp:
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
assert prefill_wrapper._context._window_left == self.window_left assert prefill_wrapper._context._window_left == self.window_left
assert prefill_wrapper._context._logits_soft_cap == ( assert prefill_wrapper._context._logits_soft_cap == (
@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
out=output[num_decode_tokens:], out=output[num_decode_tokens:],
) )
else: else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous # prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous() prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer() workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] seq_lens_prefill = attn_metadata.prefill.seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=mock_block_table, block_tables=mock_block_table,
seq_lens=seq_lens_prefill, seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len_prefill, max_q_len=attn_metadata.prefill.max_q_len,
max_kv_len=attn_metadata.max_seq_len, max_kv_len=attn_metadata.prefill.max_seq_len,
bmm1_scale=self.bmm1_scale, bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
batch_size=attn_metadata.num_prefills, batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
window_left=self.window_left, window_left=self.window_left,
sinks=self.sinks, sinks=self.sinks,
o_sf_scale=self.o_sf_scale, o_sf_scale=self.o_sf_scale,
@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
) )
if num_decode_tokens > 0: if num_decode_tokens > 0:
decode_wrapper = attn_metadata.decode_wrapper
decode_query = query[:num_decode_tokens] decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not attn_metadata.decode_use_trtllm: if not decode_use_trtllm:
assert isinstance(attn_metadata.decode, FIDecode)
decode_wrapper = attn_metadata.decode.wrapper
assert decode_wrapper is not None
assert decode_wrapper._window_left == self.window_left assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert decode_wrapper._sm_scale == self.scale assert decode_wrapper._sm_scale == self.scale
if self.dcp_world_size > 1: if use_dcp:
decode_query = get_dcp_group().all_gather( decode_query = get_dcp_group().all_gather(
decode_query.contiguous(), dim=-2 decode_query.contiguous(), dim=-2
) )
@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
) )
else: else:
# decode_query may be non-contiguous # decode_query may be non-contiguous
assert isinstance(attn_metadata.decode, TRTLLMDecode)
decode_query = decode_query.contiguous() decode_query = decode_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer() workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.block_table_tensor[ block_tables_decode = attn_metadata.decode.block_tables
:num_decode_tokens seq_lens_decode = attn_metadata.decode.seq_lens
]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables_decode, block_tables=block_tables_decode,
seq_lens=seq_lens_decode, seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len, max_seq_len=attn_metadata.decode.max_seq_len,
bmm1_scale=self.bmm1_scale, bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
window_left=self.window_left, window_left=self.window_left,