[Cleanup] Refactor FlashInferMetadataBuilder (#29128)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>

@@ -16,6 +16,7 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
paged_kv_indptr_cpu: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_last_page_len_cpu: torch.Tensor,
prefill_start: int,
page_size: int,
num_qo_heads: int,
dcp_world_size: int,
@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size,
num_kv_heads,
head_dim,
@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
@dataclass
class FlashInferMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
class FIPrefill:
"""Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
# The data type of the query
q_data_type: torch.dtype
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
@dataclass
class FIDecode:
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
wrapper: BatchDecodeWithPagedKVCacheWrapper
@dataclass
class TRTLLMPrefill:
"""Metadata for the TRTLLM prefill pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to prefill requests.
Shape: [num_prefills, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
Shape: [num_prefills]
"""
cum_seq_lens_q: torch.Tensor
cum_seq_lens_kv: torch.Tensor
max_q_len: int
"""
The maximum query length *among prefill requests*.
"""
max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class TRTLLMDecode:
"""Metadata for the TRTLLM decode pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to decode requests.
Shape: [num_decodes, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to decode requests.
Shape: [num_decodes]
"""
max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class FlashInferMetadata:
num_actual_tokens: int
"""Total number of tokens in the batch (excluding padding)."""
slot_mapping: torch.Tensor
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
# For flashinfer trtllm batch decode
max_q_len: int
max_q_len_prefill: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
prefill_use_trtllm: bool
decode_use_trtllm: bool
q_data_type: torch.dtype
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
# For cascade attention (CPU for planning).
prefill: FIPrefill | TRTLLMPrefill | None
"""
Holds the metadata for the prefill portion of the batch.
Will be `None` if `num_prefill_tokens == 0`.
"""
decode: FIDecode | TRTLLMDecode | None
"""
Holds the metadata for the decode portion of the batch.
Will be `None` if `num_decode_tokens == 0`.
"""
# --- Special Case: Cascade Attention ---
use_cascade: bool
"""
If True, the entire batch is a cascade attention call, and the
`prefill` and `decode` fields will both be None.
"""
prefill_wrapper: (
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
) = None
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: torch.Tensor | None = None
paged_kv_indptr_gpu: torch.Tensor | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
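The four pathway dataclasses above replace the old flat prefill_use_trtllm / decode_use_trtllm flags and wrapper fields: consumers now branch on the concrete type of the prefill and decode fields. A minimal dispatch sketch (describe_dispatch is a hypothetical helper for illustration, not part of this commit):

def describe_dispatch(meta: FlashInferMetadata) -> str:
    """Report which execution pathway(s) a built metadata object selects."""
    if meta.use_cascade:
        # Cascade covers the whole batch; prefill and decode are both None.
        return "cascade"
    parts: list[str] = []
    # Decodes sit at the front of the batch, prefills at the back.
    if isinstance(meta.decode, TRTLLMDecode):
        parts.append("trtllm-decode")
    elif isinstance(meta.decode, FIDecode):
        parts.append("flashinfer-decode")
    if isinstance(meta.prefill, TRTLLMPrefill):
        parts.append("trtllm-prefill")
    elif isinstance(meta.prefill, FIPrefill):
        parts.append("flashinfer-prefill")
    return " + ".join(parts) if parts else "empty"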
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1
self.use_dcp = self.dcp_world_size > 1
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs."
)
# Preparing persistent buffers (device-side)
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=self.device
)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device,
)
self.paged_kv_last_page_len = torch.zeros(
max_num_reqs, dtype=torch.int32, device=self.device
)
# host-side buffer
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indptr_buffer = torch.zeros_like(
self.paged_kv_indptr_cpu, pin_memory=pin_memory
)
self.paged_kv_indices_cpu = torch.zeros(
max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_cpu = torch.zeros(
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
# Preparing persistent buffers
self.pin_memory = is_pin_memory_available()
self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
self.paged_kv_indices = self._make_buffer(max_num_pages)
self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64."
)
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory,
with_numpy=True,
)
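CpuGpuBuffer pairs a pinned host tensor, a numpy view of it, and a device tensor of the same shape, which is what lets the builder write indptr values with numpy and then issue non-blocking H2D copies of just the slice in use. A standalone usage sketch (illustration only; it relies on the constructor arguments and the .np / .cpu / .gpu views exercised in this file, and needs a CUDA device):

import torch
from vllm.v1.utils import CpuGpuBuffer

# Construct a small buffer the same way _make_buffer does.
buf = CpuGpuBuffer(
    8, dtype=torch.int32, device="cuda", pin_memory=True, with_numpy=True
)
buf.np[:] = 0                 # the numpy view aliases the pinned CPU tensor
buf.np[1:4] = [3, 5, 7]
assert buf.cpu[2].item() == 5
# Copy only the populated slice to the GPU, asynchronously.
buf.gpu[:4].copy_(buf.cpu[:4], non_blocking=True)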
@override # type: ignore[misc]
@classmethod
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
if self._prefill_wrapper is None:
if self.dcp_world_size > 1:
if self.use_dcp:
self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(),
)
@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if decode_wrapper is None:
if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
paged_kv_indices = self.paged_kv_indices.gpu
paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
@@ -661,6 +718,60 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._cascade_wrapper
def _compute_flashinfer_kv_metadata(
self,
num_blocks_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table_tensor: torch.Tensor,
num_reqs: int,
page_size: int,
) -> torch.Tensor:
"""
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
attention.
Results are stored in self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
"""
# write self.paged_kv_indptr.cpu in place via its numpy view (index 0 is always 0)
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr.np[1 : num_reqs + 1],
)
# NOTE(woosuk): Because self.paged_kv_indptr.cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_cpu_buffer to avoid a race condition.
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
: num_reqs + 1
]
paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
paged_kv_indptr.copy_(
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
)
# write self.paged_kv_indices.gpu in place via the copy kernel
num_actual_pages = self.paged_kv_indptr.np[num_reqs]
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)
# write self.paged_kv_last_page_len.cpu in place via its numpy view
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size,
paged_kv_last_page_len_np,
)
return paged_kv_indices
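A worked numeric check of the indptr and last-page-length arithmetic above (pure numpy, illustration only), for page_size=16 and three requests of 1, 16, and 33 tokens:

import numpy as np

page_size = 16
seq_lens_np = np.array([1, 16, 33], dtype=np.int32)
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size        # [1, 1, 3]

paged_kv_indptr = np.zeros(len(seq_lens_np) + 1, dtype=np.int32)
np.cumsum(num_blocks_np, dtype=np.int32, out=paged_kv_indptr[1:])   # [0, 1, 2, 5]

last = seq_lens_np % page_size                                      # [1, 0, 1]
# A full final page reports page_size, not 0 (empty sequences stay 0).
last = np.where((last == 0) & (seq_lens_np != 0), page_size, last)  # [1, 16, 1]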
def build(
self,
common_prefix_len: int,
@@ -678,98 +789,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
qo_indptr = common_attn_metadata.query_start_loc
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
if self.dcp_world_size > 1:
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
seq_lens_np = seq_lens_cpu.numpy()
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
# Step 1: Decide which dispatch modes to use:
# - Cascade attention (distinct mode)
# - Prefill (FI native or TRTLLM)
# - Decode (FI native or TRTLLM)
use_cascade = common_prefix_len > 0
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
num_blocks_np -= num_common_kv_blocks
else:
shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr_np[1 : num_reqs + 1],
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
: num_reqs + 1
]
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
paged_kv_indptr.copy_(
self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
)
# write self.paged_kv_indices inplace
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size,
paged_kv_last_page_len_np,
)
uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads,
@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)
if not (prefill_use_trtllm and decode_use_trtllm):
all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
num_decodes == 0 or decode_use_trtllm
)
is_only_trtllm_decode = num_prefills == 0 and (
num_decodes > 0 and decode_use_trtllm
)
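# Example (hypothetical batch shapes, for illustration): num_prefills=2 with
# prefill_use_trtllm=False alongside num_decodes=4 with decode_use_trtllm=True
# gives all_uses_trtllm=False, so the native-FlashInfer checks below still run;
# num_prefills=0, num_decodes=4, decode_use_trtllm=True gives
# is_only_trtllm_decode=True, which lets build() skip the CPU-side paged-KV
# metadata (needs_seq_lens_cpu and needs_paged_kv_indices are both False).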
if not all_uses_trtllm:
if self.has_sinks:
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# fall back to model dtype.
self.q_data_type = self.model_config.dtype
# Step 2: Initialize the output metadata
# Leave prefill/decode/cascade_wrapper empty, to be populated
# case by case depending on the batch contents and backend selection.
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
prefill_use_trtllm=prefill_use_trtllm,
decode_use_trtllm=decode_use_trtllm,
q_data_type=self.q_data_type,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
prefill=None,
decode=None,
cascade_wrapper=None,
)
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
# Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = (
(seq_lens_np + (page_size - 1)) // page_size
if seq_lens_np is not None
else None
)
# Adjust seq_lens_cpu for DCP
if self.use_dcp:
assert seq_lens_cpu is not None
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
# Adjust num_blocks_np for cascade attention
if use_cascade:
assert num_blocks_np is not None
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
num_blocks_np -= num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
if needs_paged_kv_indices:
assert num_blocks_np is not None
assert seq_lens_np is not None
paged_kv_indices = self._compute_flashinfer_kv_metadata(
num_blocks_np,
seq_lens_np,
block_table_tensor,
num_reqs,
page_size,
)
else:
paged_kv_indices = None
# Early-out for cascade attention
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
num_blocks_np -= num_common_kv_blocks
assert paged_kv_indices is not None
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu],
@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
assert (
paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
)
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu = (
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
)
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
return attn_metadata
# Recompute max_q_len for the slice of requests we are using
# for prefills. This can be different from max_q_len when
# we have a non-uniform batch with some short decodes offloaded
# to the prefill pathway
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
# Step 3: Handle prefill and decode pathways case by case
## PREFILL PATHWAY
if num_prefills > 0:
# Slices for shared prefill metadata
prefill_start = num_decodes
qo_indptr_prefill_cpu = (
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
)
assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1
if not attn_metadata.prefill_use_trtllm:
if self.dcp_world_size > 1:
assert isinstance(
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu=qo_indptr_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
prefill_start=prefill_start,
page_size=self.page_size,
num_qo_heads=self.num_qo_heads,
dcp_world_size=self.dcp_world_size,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_cache_dtype=self.kv_cache_dtype,
prefill_fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else:
assert isinstance(
attn_metadata.prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
if prefill_use_trtllm:
# Create GPU versions
qo_indptr_prefill_gpu = (
qo_indptr[prefill_start:] - qo_indptr[prefill_start]
)
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
prefill_start : num_reqs + 1
]
# Compute max_q_len for prefill requests
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
attn_metadata.prefill = TRTLLMPrefill(
block_tables=block_table_tensor[prefill_start:],
seq_lens=seq_lens[prefill_start:],
cum_seq_lens_q=qo_indptr_prefill_gpu,
cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
max_q_len=max_q_len_prefill,
max_seq_len=max_seq_len,
)
else:
prefill_wrapper = self._get_prefill_wrapper()
# Slicing CPU buffers that are only needed for FI native prefills
paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
prefill_start:num_reqs
]
assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
prefill_start : num_reqs + 1
]
assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
if self.use_dcp:
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
prefill_wrapper.plan(
qo_indptr_cpu=qo_indptr_prefill_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
page_size=self.page_size,
num_qo_heads=self.num_qo_heads,
dcp_world_size=self.dcp_world_size,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_cache_dtype=self.kv_cache_dtype,
prefill_fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
self.device, non_blocking=True
assert isinstance(
prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
self.device, non_blocking=True
prefill_wrapper.plan(
qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu,
paged_kv_indices,
paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
if num_decodes > 0:
## DECODE PATHWAY
if num_decodes > 0:
if decode_use_trtllm:
assert num_decode_tokens % num_decodes == 0, (
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata.decode = TRTLLMDecode(
block_tables=block_table_tensor[:num_decodes],
seq_lens=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
)
else:
pure_decode = num_prefills == 0
use_cudagraph = (
self.enable_cuda_graph
@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper(
decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph
)
if not attn_metadata.decode_use_trtllm:
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in atten_metadata when using cudagraph.
fast_plan_decode(
attn_metadata.decode_wrapper,
self.paged_kv_indptr_cpu[: num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads * self.dcp_world_size,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.decode_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
# Use the persistent buffer with the padded length, instead of the
# chunked view at the same address in attn_metadata, when using
# cudagraph.
fast_plan_decode(
decode_wrapper,
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len.cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads * self.dcp_world_size,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.decode_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
# The attn+quant fusion happens when output_scale is provided.
if output_scale is None:
assert output_block_scale is None, (
@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
assert attn_metadata.q_data_type == FP8_DTYPE, (
"Query must be FP8 when attn+quant fusion happened."
)
assert (
attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
attn_metadata.num_decodes == 0 or decode_use_trtllm
), "Must use TRT-LLM attn"
if output.dtype == FP8_DTYPE:
@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order)
use_dcp = self.dcp_world_size > 1
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert prefill_wrapper is not None
if not attn_metadata.prefill_use_trtllm:
if self.dcp_world_size > 1:
if not prefill_use_trtllm:
assert isinstance(attn_metadata.prefill, FIPrefill)
prefill_wrapper = attn_metadata.prefill.wrapper
assert prefill_wrapper is not None
if use_dcp:
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
assert prefill_wrapper._context._window_left == self.window_left
assert prefill_wrapper._context._logits_soft_cap == (
@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
out=output[num_decode_tokens:],
)
else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.prefill.seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len_prefill,
max_kv_len=attn_metadata.max_seq_len,
max_q_len=attn_metadata.prefill.max_q_len,
max_kv_len=attn_metadata.prefill.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
window_left=self.window_left,
sinks=self.sinks,
o_sf_scale=self.o_sf_scale,
@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
)
if num_decode_tokens > 0:
decode_wrapper = attn_metadata.decode_wrapper
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not attn_metadata.decode_use_trtllm:
if not decode_use_trtllm:
assert isinstance(attn_metadata.decode, FIDecode)
decode_wrapper = attn_metadata.decode.wrapper
assert decode_wrapper is not None
assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert decode_wrapper._sm_scale == self.scale
if self.dcp_world_size > 1:
if use_dcp:
decode_query = get_dcp_group().all_gather(
decode_query.contiguous(), dim=-2
)
@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
)
else:
# decode_query may be non-contiguous
assert isinstance(attn_metadata.decode, TRTLLMDecode)
decode_query = decode_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.block_table_tensor[
:num_decode_tokens
]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
block_tables_decode = attn_metadata.decode.block_tables
seq_lens_decode = attn_metadata.decode.seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
max_seq_len=attn_metadata.decode.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
window_left=self.window_left,