mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 22:57:32 +08:00
[Cleanup] Refactor FlashInferMetadataBuilder (#29128)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
6ca74bc11a
commit
d6b3d39b6d
@ -16,6 +16,7 @@ from flashinfer import (
|
||||
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
||||
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||
from flashinfer.utils import FP4Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
|
||||
|
||||
@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
|
||||
paged_kv_indptr_cpu: torch.Tensor,
|
||||
paged_kv_indices: torch.Tensor,
|
||||
paged_kv_last_page_len_cpu: torch.Tensor,
|
||||
prefill_start: int,
|
||||
page_size: int,
|
||||
num_qo_heads: int,
|
||||
dcp_world_size: int,
|
||||
@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
|
||||
qo_indptr_cpu,
|
||||
paged_kv_indptr_cpu,
|
||||
paged_kv_indices,
|
||||
paged_kv_last_page_len_cpu[prefill_start:],
|
||||
paged_kv_last_page_len_cpu,
|
||||
num_qo_heads * dcp_world_size,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
class FIPrefill:
|
||||
"""Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
|
||||
|
||||
# The data type of the query
|
||||
q_data_type: torch.dtype
|
||||
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class FIDecode:
|
||||
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
|
||||
|
||||
wrapper: BatchDecodeWithPagedKVCacheWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class TRTLLMPrefill:
|
||||
"""Metadata for the TRTLLM prefill pathway."""
|
||||
|
||||
block_tables: torch.Tensor
|
||||
"""
|
||||
The slice of the block table tensor corresponding *only* to prefill requests.
|
||||
Shape: [num_prefills, max_num_blocks_per_seq]
|
||||
"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""
|
||||
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
|
||||
Shape: [num_prefills]
|
||||
"""
|
||||
|
||||
cum_seq_lens_q: torch.Tensor
|
||||
cum_seq_lens_kv: torch.Tensor
|
||||
|
||||
max_q_len: int
|
||||
"""
|
||||
The maximum query length *among prefill requests*.
|
||||
"""
|
||||
|
||||
max_seq_len: int
|
||||
"""The maximum sequence length for KV Cache."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TRTLLMDecode:
|
||||
"""Metadata for the TRTLLM decode pathway."""
|
||||
|
||||
block_tables: torch.Tensor
|
||||
"""
|
||||
The slice of the block table tensor corresponding *only* to decode requests.
|
||||
Shape: [num_decodes, max_num_blocks_per_seq]
|
||||
"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""
|
||||
The slice of the sequence lengths tensor corresponding *only* to decode requests.
|
||||
Shape: [num_decodes]
|
||||
"""
|
||||
|
||||
max_seq_len: int
|
||||
"""The maximum sequence length for KV Cache."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in the batch (excluding padding)."""
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
|
||||
|
||||
# For flashinfer trtllm batch decode
|
||||
max_q_len: int
|
||||
max_q_len_prefill: int
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table_tensor: torch.Tensor
|
||||
prefill_use_trtllm: bool
|
||||
decode_use_trtllm: bool
|
||||
q_data_type: torch.dtype
|
||||
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
# For cascade attention (CPU for planning).
|
||||
prefill: FIPrefill | TRTLLMPrefill | None
|
||||
"""
|
||||
Holds the metadata for the prefill portion of the batch.
|
||||
Will be `None` if `num_prefill_tokens == 0`.
|
||||
"""
|
||||
|
||||
decode: FIDecode | TRTLLMDecode | None
|
||||
"""
|
||||
Holds the metadata for the decode portion of the batch.
|
||||
Will be `None` if `num_decode_tokens == 0`.
|
||||
"""
|
||||
|
||||
# --- Special Case: Cascade Attention ---
|
||||
|
||||
use_cascade: bool
|
||||
"""
|
||||
If True, the entire batch is a cascade attention call, and the
|
||||
`prefill` and `decode` fields will both be None.
|
||||
"""
|
||||
|
||||
prefill_wrapper: (
|
||||
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
|
||||
) = None
|
||||
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
|
||||
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
|
||||
|
||||
qo_indptr_gpu: torch.Tensor | None = None
|
||||
paged_kv_indptr_gpu: torch.Tensor | None = None
|
||||
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_kv_cache_interleave_size = 1
|
||||
self.use_dcp = self.dcp_world_size > 1
|
||||
|
||||
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config
|
||||
@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
"sinks, please use trtllm on blackwell or flash attention on "
|
||||
"earlier GPUs."
|
||||
)
|
||||
# Preparing persistent buffers (device-side)
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, # max num pages possible
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.paged_kv_last_page_len = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
# host-side buffer
|
||||
pin_memory = is_pin_memory_available()
|
||||
self.paged_kv_indptr_cpu = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
|
||||
self.paged_kv_indptr_buffer = torch.zeros_like(
|
||||
self.paged_kv_indptr_cpu, pin_memory=pin_memory
|
||||
)
|
||||
self.paged_kv_indices_cpu = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.paged_kv_last_page_len_cpu = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
|
||||
# Preparing persistent buffers
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
|
||||
self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
|
||||
self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
|
||||
) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
|
||||
self.paged_kv_indices = self._make_buffer(max_num_pages)
|
||||
self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
|
||||
|
||||
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
|
||||
@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
"passing --block-size 32 or --block-size 64."
|
||||
)
|
||||
|
||||
def _make_buffer(
|
||||
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
|
||||
) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
with_numpy=True,
|
||||
)
|
||||
|
||||
@override # type: ignore[misc]
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["FlashInferMetadataBuilder"],
|
||||
@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self,
|
||||
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
|
||||
if self._prefill_wrapper is None:
|
||||
if self.dcp_world_size > 1:
|
||||
if self.use_dcp:
|
||||
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
||||
workspace_buffer=self._get_workspace_buffer(),
|
||||
)
|
||||
@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
if decode_wrapper is None:
|
||||
if use_cudagraph:
|
||||
paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
|
||||
paged_kv_indices = self.paged_kv_indices
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
|
||||
paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
|
||||
paged_kv_indices = self.paged_kv_indices.gpu
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
|
||||
else:
|
||||
paged_kv_indptr = None
|
||||
paged_kv_indices = None
|
||||
@ -661,6 +718,60 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
)
|
||||
return self._cascade_wrapper
|
||||
|
||||
def _compute_flashinfer_kv_metadata(
|
||||
self,
|
||||
num_blocks_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table_tensor: torch.Tensor,
|
||||
num_reqs: int,
|
||||
page_size: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
|
||||
attention.
|
||||
|
||||
Results are stored in self.paged_kv_indptr,
|
||||
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
|
||||
|
||||
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
|
||||
"""
|
||||
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
|
||||
np.cumsum(
|
||||
num_blocks_np,
|
||||
dtype=np.int32,
|
||||
out=self.paged_kv_indptr.np[1 : num_reqs + 1],
|
||||
)
|
||||
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
|
||||
# after this line (e.g., for cuda graphs), we need to copy the data to
|
||||
# self.paged_kv_indptr_buffer to avoid race condition.
|
||||
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
|
||||
: num_reqs + 1
|
||||
]
|
||||
paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
|
||||
paged_kv_indptr.copy_(
|
||||
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
|
||||
)
|
||||
|
||||
# write self.paged_kv_indices inplace
|
||||
num_actual_pages = self.paged_kv_indptr.np[num_reqs]
|
||||
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
|
||||
_copy_page_indices_kernel[(num_reqs,)](
|
||||
paged_kv_indices,
|
||||
block_table_tensor,
|
||||
block_table_tensor.stride(0),
|
||||
paged_kv_indptr,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
# write self.paged_kv_last_page_len_cpu inplace
|
||||
paged_kv_last_page_len_np = seq_lens_np % page_size
|
||||
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
|
||||
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
|
||||
page_size,
|
||||
paged_kv_last_page_len_np,
|
||||
)
|
||||
return paged_kv_indices
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@ -678,98 +789,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
)
|
||||
|
||||
page_size = self.page_size
|
||||
max_q_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
if num_prefills > 0:
|
||||
qo_indptr_prefill_cpu = (
|
||||
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
|
||||
)
|
||||
query_lens_prefill_cpu = (
|
||||
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
||||
)
|
||||
seq_lens_cpu[num_decodes:] = (
|
||||
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
|
||||
)
|
||||
|
||||
seq_lens_cpu = get_dcp_local_seq_lens(
|
||||
seq_lens_cpu,
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.dcp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
seq_lens_np = seq_lens_cpu.numpy()
|
||||
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
|
||||
|
||||
# Step 1: Decide which dispatch modes to use:
|
||||
# - Cascade attention (distinct mode)
|
||||
# - Prefill (FI native or TRTLLM)
|
||||
# - Decode (FI native or TRTLLM)
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
# Grab the blocks of the shared prefix from the first request.
|
||||
assert common_prefix_len % page_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // page_size
|
||||
|
||||
# Create CPU versions directly for cascade (no GPU versions needed)
|
||||
shared_qo_indptr_cpu = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
shared_kv_page_indptr_cpu = torch.tensor(
|
||||
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
|
||||
shared_kv_last_page_len_cpu = torch.tensor(
|
||||
[page_size], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
|
||||
# Remove the blocks of the shared prefix from all requests.
|
||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||
num_blocks_np -= num_common_kv_blocks
|
||||
else:
|
||||
shared_qo_indptr_cpu = None
|
||||
shared_kv_page_indptr_cpu = None
|
||||
shared_kv_page_indices_cpu = None
|
||||
shared_kv_last_page_len_cpu = None
|
||||
|
||||
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
|
||||
np.cumsum(
|
||||
num_blocks_np,
|
||||
dtype=np.int32,
|
||||
out=self.paged_kv_indptr_np[1 : num_reqs + 1],
|
||||
)
|
||||
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
|
||||
# after this line (e.g., for cuda graphs), we need to copy the data to
|
||||
# self.paged_kv_indptr_buffer to avoid race condition.
|
||||
self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
|
||||
: num_reqs + 1
|
||||
]
|
||||
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
|
||||
paged_kv_indptr.copy_(
|
||||
self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
|
||||
)
|
||||
|
||||
# write self.paged_kv_indices inplace
|
||||
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
_copy_page_indices_kernel[(num_reqs,)](
|
||||
paged_kv_indices,
|
||||
block_table_tensor,
|
||||
block_table_tensor.stride(0),
|
||||
paged_kv_indptr,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
# write self.paged_kv_last_page_len_cpu inplace
|
||||
paged_kv_last_page_len_np = seq_lens_np % page_size
|
||||
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
|
||||
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
|
||||
page_size,
|
||||
paged_kv_last_page_len_np,
|
||||
)
|
||||
|
||||
uses_spec_reorder = self.reorder_batch_threshold > 1
|
||||
prefill_use_trtllm = use_trtllm_attention(
|
||||
self.num_qo_heads,
|
||||
@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
|
||||
)
|
||||
|
||||
if not (prefill_use_trtllm and decode_use_trtllm):
|
||||
all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
|
||||
num_decodes == 0 or decode_use_trtllm
|
||||
)
|
||||
is_only_trtllm_decode = num_prefills == 0 and (
|
||||
num_decodes > 0 and decode_use_trtllm
|
||||
)
|
||||
|
||||
if not all_uses_trtllm:
|
||||
if self.has_sinks:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer backend currently does not support attention "
|
||||
@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# fall back to model dtype.
|
||||
self.q_data_type = self.model_config.dtype
|
||||
|
||||
# Step 2: Initialize the output metadata
|
||||
# Leave prefill/decode/cascade_wrapper empty, to be populated
|
||||
# case by case depending on the batch contents and backend selection.
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
q_data_type=self.q_data_type,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
max_q_len=max_q_len,
|
||||
max_q_len_prefill=max_q_len,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table_tensor=block_table_tensor,
|
||||
prefill_use_trtllm=prefill_use_trtllm,
|
||||
decode_use_trtllm=decode_use_trtllm,
|
||||
q_data_type=self.q_data_type,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
use_cascade=use_cascade,
|
||||
prefill=None,
|
||||
decode=None,
|
||||
cascade_wrapper=None,
|
||||
)
|
||||
|
||||
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
|
||||
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
|
||||
# Guard access to seq_lens_cpu, which may not always be needed
|
||||
# and can be expensive to retrieve in async mode.
|
||||
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
|
||||
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
|
||||
num_blocks_np = (
|
||||
(seq_lens_np + (page_size - 1)) // page_size
|
||||
if seq_lens_np is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Adjust seq_lens_cpu for DCP
|
||||
if self.use_dcp:
|
||||
assert seq_lens_cpu is not None
|
||||
if num_prefills > 0:
|
||||
qo_indptr_prefill_cpu = (
|
||||
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
|
||||
)
|
||||
query_lens_prefill_cpu = (
|
||||
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
||||
)
|
||||
seq_lens_cpu[num_decodes:] = (
|
||||
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
|
||||
)
|
||||
|
||||
seq_lens_cpu = get_dcp_local_seq_lens(
|
||||
seq_lens_cpu,
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.dcp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
# Adjust num_block_np for cascade attention
|
||||
if use_cascade:
|
||||
assert num_blocks_np is not None
|
||||
assert common_prefix_len % page_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // page_size
|
||||
num_blocks_np -= num_common_kv_blocks
|
||||
|
||||
# Compute paged_kv_indices if necessary
|
||||
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
|
||||
if needs_paged_kv_indices:
|
||||
assert num_blocks_np is not None
|
||||
assert seq_lens_np is not None
|
||||
paged_kv_indices = self._compute_flashinfer_kv_metadata(
|
||||
num_blocks_np,
|
||||
seq_lens_np,
|
||||
block_table_tensor,
|
||||
num_reqs,
|
||||
page_size,
|
||||
)
|
||||
else:
|
||||
paged_kv_indices = None
|
||||
|
||||
# Early-out for cascade attention
|
||||
if use_cascade:
|
||||
# Grab the blocks of the shared prefix from the first request.
|
||||
num_common_kv_blocks = common_prefix_len // page_size
|
||||
|
||||
# Create CPU versions directly for cascade (no GPU versions needed)
|
||||
shared_qo_indptr_cpu = torch.tensor(
|
||||
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
shared_kv_page_indptr_cpu = torch.tensor(
|
||||
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
|
||||
shared_kv_last_page_len_cpu = torch.tensor(
|
||||
[page_size], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
|
||||
# Remove the blocks of the shared prefix from all requests.
|
||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||
num_blocks_np -= num_common_kv_blocks
|
||||
|
||||
assert paged_kv_indices is not None
|
||||
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
|
||||
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
|
||||
|
||||
if attn_metadata.use_cascade:
|
||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||
attn_metadata.cascade_wrapper.plan(
|
||||
[shared_qo_indptr_cpu, qo_indptr_cpu],
|
||||
@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back.
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
if num_prefills > 0:
|
||||
# Decodes are first so prefills start after the last decode
|
||||
prefill_start = num_decodes
|
||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||
assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
|
||||
assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
|
||||
assert (
|
||||
paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
|
||||
)
|
||||
# Since prefill_wrapper.run() will be called with
|
||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||
# to be relative to the start of the prefill queries.
|
||||
qo_indptr_cpu = (
|
||||
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
|
||||
)
|
||||
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
|
||||
return attn_metadata
|
||||
|
||||
# Recompute max_q_len for the slice of requests we are using
|
||||
# for prefills. This can be different from max_q_len when
|
||||
# we have a non-uniform batch with some short decodes offloaded
|
||||
# to the prefill pathway
|
||||
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
|
||||
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
|
||||
# Step 3: Handle prefill and decode pathways case by case
|
||||
## PREFILL PATHWAY
|
||||
if num_prefills > 0:
|
||||
# Slices for shared prefill metadata
|
||||
prefill_start = num_decodes
|
||||
qo_indptr_prefill_cpu = (
|
||||
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
|
||||
)
|
||||
assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1
|
||||
|
||||
if not attn_metadata.prefill_use_trtllm:
|
||||
if self.dcp_world_size > 1:
|
||||
assert isinstance(
|
||||
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
|
||||
)
|
||||
attn_metadata.prefill_wrapper.plan(
|
||||
qo_indptr_cpu=qo_indptr_cpu,
|
||||
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
|
||||
prefill_start=prefill_start,
|
||||
page_size=self.page_size,
|
||||
num_qo_heads=self.num_qo_heads,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
prefill_fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
attn_metadata.prefill_wrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
attn_metadata.prefill_wrapper.plan(
|
||||
qo_indptr_cpu,
|
||||
paged_kv_indptr_cpu,
|
||||
paged_kv_indices,
|
||||
paged_kv_last_page_len_cpu[prefill_start:],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
causal=True,
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
if prefill_use_trtllm:
|
||||
# Create GPU versions
|
||||
qo_indptr_prefill_gpu = (
|
||||
qo_indptr[prefill_start:] - qo_indptr[prefill_start]
|
||||
)
|
||||
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
|
||||
prefill_start : num_reqs + 1
|
||||
]
|
||||
# Compute max_q_len for prefill requests
|
||||
query_lens_prefill_cpu = (
|
||||
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
||||
)
|
||||
max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
|
||||
attn_metadata.prefill = TRTLLMPrefill(
|
||||
block_tables=block_table_tensor[prefill_start:],
|
||||
seq_lens=seq_lens[prefill_start:],
|
||||
cum_seq_lens_q=qo_indptr_prefill_gpu,
|
||||
cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
|
||||
max_q_len=max_q_len_prefill,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
else:
|
||||
prefill_wrapper = self._get_prefill_wrapper()
|
||||
# Slicing CPU buffers that are only needed for FI native prefills
|
||||
paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
|
||||
prefill_start:num_reqs
|
||||
]
|
||||
assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
|
||||
paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
|
||||
prefill_start : num_reqs + 1
|
||||
]
|
||||
assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
|
||||
if self.use_dcp:
|
||||
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
|
||||
prefill_wrapper.plan(
|
||||
qo_indptr_cpu=qo_indptr_prefill_cpu,
|
||||
paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
|
||||
page_size=self.page_size,
|
||||
num_qo_heads=self.num_qo_heads,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
prefill_fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
else:
|
||||
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
|
||||
self.device, non_blocking=True
|
||||
assert isinstance(
|
||||
prefill_wrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
|
||||
self.device, non_blocking=True
|
||||
prefill_wrapper.plan(
|
||||
qo_indptr_prefill_cpu,
|
||||
paged_kv_indptr_prefill_cpu,
|
||||
paged_kv_indices,
|
||||
paged_kv_last_page_len_prefill_cpu,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
causal=True,
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
|
||||
|
||||
if num_decodes > 0:
|
||||
## DECODE PATHWAY
|
||||
if num_decodes > 0:
|
||||
if decode_use_trtllm:
|
||||
assert num_decode_tokens % num_decodes == 0, (
|
||||
"TRTLLM decode requires uniform query lengths per request."
|
||||
)
|
||||
attn_metadata.decode = TRTLLMDecode(
|
||||
block_tables=block_table_tensor[:num_decodes],
|
||||
seq_lens=seq_lens[:num_decodes],
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
else:
|
||||
pure_decode = num_prefills == 0
|
||||
use_cudagraph = (
|
||||
self.enable_cuda_graph
|
||||
@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
)
|
||||
num_input_tokens = num_decode_tokens
|
||||
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper(
|
||||
decode_wrapper = self._get_decode_wrapper(
|
||||
num_input_tokens, use_cudagraph
|
||||
)
|
||||
if not attn_metadata.decode_use_trtllm:
|
||||
# Use the persistent buffer with padding length,
|
||||
# instead of the same address but chunked version
|
||||
# in atten_metadata when using cudagraph.
|
||||
fast_plan_decode(
|
||||
attn_metadata.decode_wrapper,
|
||||
self.paged_kv_indptr_cpu[: num_input_tokens + 1],
|
||||
paged_kv_indices,
|
||||
self.paged_kv_last_page_len_cpu[:num_input_tokens],
|
||||
seq_lens_cpu[:num_input_tokens],
|
||||
self.num_qo_heads * self.dcp_world_size,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
fixed_split_size=self.decode_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
# Use the persistent buffer with padding length,
|
||||
# instead of the same address but chunked version
|
||||
# in atten_metadata when using cudagraph.
|
||||
fast_plan_decode(
|
||||
decode_wrapper,
|
||||
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
|
||||
paged_kv_indices,
|
||||
self.paged_kv_last_page_len.cpu[:num_input_tokens],
|
||||
seq_lens_cpu[:num_input_tokens],
|
||||
self.num_qo_heads * self.dcp_world_size,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
fixed_split_size=self.decode_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||
|
||||
# The attn+quant fusion happens when output_scale is provided.
|
||||
if output_scale is None:
|
||||
assert output_block_scale is None, (
|
||||
@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert attn_metadata.q_data_type == FP8_DTYPE, (
|
||||
"Query must be FP8 when attn+quant fusion happened."
|
||||
)
|
||||
assert (
|
||||
attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
|
||||
assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
|
||||
attn_metadata.num_decodes == 0 or decode_use_trtllm
|
||||
), "Must use TRT-LLM attn"
|
||||
|
||||
if output.dtype == FP8_DTYPE:
|
||||
@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
# When using spec decoding, num_decodes can be < num_decode_tokens
|
||||
# because some decode requests may have more than one query token.
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
|
||||
stride_order = FlashInferBackend.get_kv_cache_stride_order()
|
||||
kv_cache_permute = kv_cache.permute(*stride_order)
|
||||
|
||||
use_dcp = self.dcp_world_size > 1
|
||||
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back.
|
||||
if num_prefill_tokens > 0:
|
||||
prefill_wrapper = attn_metadata.prefill_wrapper
|
||||
prefill_query = query[num_decode_tokens:]
|
||||
assert prefill_query.shape[0] == num_prefill_tokens
|
||||
assert prefill_wrapper is not None
|
||||
|
||||
if not attn_metadata.prefill_use_trtllm:
|
||||
if self.dcp_world_size > 1:
|
||||
if not prefill_use_trtllm:
|
||||
assert isinstance(attn_metadata.prefill, FIPrefill)
|
||||
prefill_wrapper = attn_metadata.prefill.wrapper
|
||||
assert prefill_wrapper is not None
|
||||
if use_dcp:
|
||||
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
|
||||
assert prefill_wrapper._context._window_left == self.window_left
|
||||
assert prefill_wrapper._context._logits_soft_cap == (
|
||||
@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
|
||||
out=output[num_decode_tokens:],
|
||||
)
|
||||
else:
|
||||
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||
# prefill_query may be non-contiguous
|
||||
prefill_query = prefill_query.contiguous()
|
||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||
block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
|
||||
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
|
||||
block_tables_prefill = attn_metadata.prefill.block_tables
|
||||
seq_lens_prefill = attn_metadata.prefill.seq_lens
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=mock_block_table,
|
||||
seq_lens=seq_lens_prefill,
|
||||
max_q_len=attn_metadata.max_q_len_prefill,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
max_q_len=attn_metadata.prefill.max_q_len,
|
||||
max_kv_len=attn_metadata.prefill.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
|
||||
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
|
||||
cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
|
||||
cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
|
||||
window_left=self.window_left,
|
||||
sinks=self.sinks,
|
||||
o_sf_scale=self.o_sf_scale,
|
||||
@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
if num_decode_tokens > 0:
|
||||
decode_wrapper = attn_metadata.decode_wrapper
|
||||
decode_query = query[:num_decode_tokens]
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
assert decode_wrapper is not None
|
||||
|
||||
if not attn_metadata.decode_use_trtllm:
|
||||
if not decode_use_trtllm:
|
||||
assert isinstance(attn_metadata.decode, FIDecode)
|
||||
decode_wrapper = attn_metadata.decode.wrapper
|
||||
assert decode_wrapper is not None
|
||||
assert decode_wrapper._window_left == self.window_left
|
||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
if use_dcp:
|
||||
decode_query = get_dcp_group().all_gather(
|
||||
decode_query.contiguous(), dim=-2
|
||||
)
|
||||
@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
else:
|
||||
# decode_query may be non-contiguous
|
||||
assert isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||
decode_query = decode_query.contiguous()
|
||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||
block_tables_decode = attn_metadata.block_table_tensor[
|
||||
:num_decode_tokens
|
||||
]
|
||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
||||
block_tables_decode = attn_metadata.decode.block_tables
|
||||
seq_lens_decode = attn_metadata.decode.seq_lens
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables_decode,
|
||||
seq_lens=seq_lens_decode,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
max_seq_len=attn_metadata.decode.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
window_left=self.window_left,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user