mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 22:37:07 +08:00
[Cleanup] Refactor FlashInferMetadataBuilder (#29128)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
6ca74bc11a
commit
d6b3d39b6d
@ -16,6 +16,7 @@ from flashinfer import (
|
|||||||
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
||||||
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||||
from flashinfer.utils import FP4Tensor
|
from flashinfer.utils import FP4Tensor
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (
|
from vllm.attention.backends.abstract import (
|
||||||
@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
|
|
||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
|
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
|
||||||
|
|
||||||
@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
|
|||||||
paged_kv_indptr_cpu: torch.Tensor,
|
paged_kv_indptr_cpu: torch.Tensor,
|
||||||
paged_kv_indices: torch.Tensor,
|
paged_kv_indices: torch.Tensor,
|
||||||
paged_kv_last_page_len_cpu: torch.Tensor,
|
paged_kv_last_page_len_cpu: torch.Tensor,
|
||||||
prefill_start: int,
|
|
||||||
page_size: int,
|
page_size: int,
|
||||||
num_qo_heads: int,
|
num_qo_heads: int,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
|
|||||||
qo_indptr_cpu,
|
qo_indptr_cpu,
|
||||||
paged_kv_indptr_cpu,
|
paged_kv_indptr_cpu,
|
||||||
paged_kv_indices,
|
paged_kv_indices,
|
||||||
paged_kv_last_page_len_cpu[prefill_start:],
|
paged_kv_last_page_len_cpu,
|
||||||
num_qo_heads * dcp_world_size,
|
num_qo_heads * dcp_world_size,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashInferMetadata:
|
class FIPrefill:
|
||||||
num_actual_tokens: int # Number of tokens excluding padding.
|
"""Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
|
||||||
|
|
||||||
# The data type of the query
|
wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
|
||||||
q_data_type: torch.dtype
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FIDecode:
|
||||||
|
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
|
||||||
|
|
||||||
|
wrapper: BatchDecodeWithPagedKVCacheWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TRTLLMPrefill:
|
||||||
|
"""Metadata for the TRTLLM prefill pathway."""
|
||||||
|
|
||||||
|
block_tables: torch.Tensor
|
||||||
|
"""
|
||||||
|
The slice of the block table tensor corresponding *only* to prefill requests.
|
||||||
|
Shape: [num_prefills, max_num_blocks_per_seq]
|
||||||
|
"""
|
||||||
|
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
"""
|
||||||
|
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
|
||||||
|
Shape: [num_prefills]
|
||||||
|
"""
|
||||||
|
|
||||||
|
cum_seq_lens_q: torch.Tensor
|
||||||
|
cum_seq_lens_kv: torch.Tensor
|
||||||
|
|
||||||
|
max_q_len: int
|
||||||
|
"""
|
||||||
|
The maximum query length *among prefill requests*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_seq_len: int
|
||||||
|
"""The maximum sequence length for KV Cache."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TRTLLMDecode:
|
||||||
|
"""Metadata for the TRTLLM decode pathway."""
|
||||||
|
|
||||||
|
block_tables: torch.Tensor
|
||||||
|
"""
|
||||||
|
The slice of the block table tensor corresponding *only* to decode requests.
|
||||||
|
Shape: [num_decodes, max_num_blocks_per_seq]
|
||||||
|
"""
|
||||||
|
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
"""
|
||||||
|
The slice of the sequence lengths tensor corresponding *only* to decode requests.
|
||||||
|
Shape: [num_decodes]
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_seq_len: int
|
||||||
|
"""The maximum sequence length for KV Cache."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlashInferMetadata:
|
||||||
|
num_actual_tokens: int
|
||||||
|
"""Total number of tokens in the batch (excluding padding)."""
|
||||||
|
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
|
||||||
|
|
||||||
# For flashinfer trtllm batch decode
|
q_data_type: torch.dtype
|
||||||
max_q_len: int
|
|
||||||
max_q_len_prefill: int
|
|
||||||
max_seq_len: int
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
block_table_tensor: torch.Tensor
|
|
||||||
prefill_use_trtllm: bool
|
|
||||||
decode_use_trtllm: bool
|
|
||||||
|
|
||||||
# For handling prefill decode split
|
|
||||||
num_decodes: int
|
num_decodes: int
|
||||||
num_decode_tokens: int
|
num_decode_tokens: int
|
||||||
num_prefills: int
|
num_prefills: int
|
||||||
num_prefill_tokens: int
|
num_prefill_tokens: int
|
||||||
|
|
||||||
# For cascade attention (CPU for planning).
|
prefill: FIPrefill | TRTLLMPrefill | None
|
||||||
|
"""
|
||||||
|
Holds the metadata for the prefill portion of the batch.
|
||||||
|
Will be `None` if `num_prefill_tokens == 0`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
decode: FIDecode | TRTLLMDecode | None
|
||||||
|
"""
|
||||||
|
Holds the metadata for the decode portion of the batch.
|
||||||
|
Will be `None` if `num_decode_tokens == 0`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- Special Case: Cascade Attention ---
|
||||||
|
|
||||||
use_cascade: bool
|
use_cascade: bool
|
||||||
|
"""
|
||||||
|
If True, the entire batch is a cascade attention call, and the
|
||||||
|
`prefill` and `decode` fields will both be None.
|
||||||
|
"""
|
||||||
|
|
||||||
prefill_wrapper: (
|
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
|
||||||
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
|
|
||||||
) = None
|
|
||||||
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
|
|
||||||
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
|
|
||||||
|
|
||||||
qo_indptr_gpu: torch.Tensor | None = None
|
|
||||||
paged_kv_indptr_gpu: torch.Tensor | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||||
@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.dcp_kv_cache_interleave_size = 1
|
self.dcp_kv_cache_interleave_size = 1
|
||||||
|
self.use_dcp = self.dcp_world_size > 1
|
||||||
|
|
||||||
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
||||||
self.vllm_config.parallel_config
|
self.vllm_config.parallel_config
|
||||||
@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
"sinks, please use trtllm on blackwell or flash attention on "
|
"sinks, please use trtllm on blackwell or flash attention on "
|
||||||
"earlier GPUs."
|
"earlier GPUs."
|
||||||
)
|
)
|
||||||
# Preparing persistent buffers (device-side)
|
# Preparing persistent buffers
|
||||||
self.paged_kv_indptr = torch.zeros(
|
self.pin_memory = is_pin_memory_available()
|
||||||
max_num_reqs + 1, dtype=torch.int32, device=self.device
|
self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
|
||||||
)
|
self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
|
||||||
self.paged_kv_indices = torch.zeros(
|
self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
|
||||||
max_num_pages, # max num pages possible
|
) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
|
||||||
dtype=torch.int32,
|
self.paged_kv_indices = self._make_buffer(max_num_pages)
|
||||||
device=self.device,
|
self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
|
||||||
)
|
|
||||||
self.paged_kv_last_page_len = torch.zeros(
|
|
||||||
max_num_reqs, dtype=torch.int32, device=self.device
|
|
||||||
)
|
|
||||||
# host-side buffer
|
|
||||||
pin_memory = is_pin_memory_available()
|
|
||||||
self.paged_kv_indptr_cpu = torch.zeros(
|
|
||||||
max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
|
||||||
)
|
|
||||||
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
|
|
||||||
self.paged_kv_indptr_buffer = torch.zeros_like(
|
|
||||||
self.paged_kv_indptr_cpu, pin_memory=pin_memory
|
|
||||||
)
|
|
||||||
self.paged_kv_indices_cpu = torch.zeros(
|
|
||||||
max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
|
||||||
)
|
|
||||||
self.paged_kv_last_page_len_cpu = torch.zeros(
|
|
||||||
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
|
||||||
)
|
|
||||||
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
|
|
||||||
|
|
||||||
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
|
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
|
||||||
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
|
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
|
||||||
@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
"passing --block-size 32 or --block-size 64."
|
"passing --block-size 32 or --block-size 64."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_buffer(
|
||||||
|
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
|
||||||
|
) -> CpuGpuBuffer:
|
||||||
|
return CpuGpuBuffer(
|
||||||
|
*size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
with_numpy=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override # type: ignore[misc]
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
cls: type["FlashInferMetadataBuilder"],
|
cls: type["FlashInferMetadataBuilder"],
|
||||||
@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self,
|
self,
|
||||||
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
|
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
|
||||||
if self._prefill_wrapper is None:
|
if self._prefill_wrapper is None:
|
||||||
if self.dcp_world_size > 1:
|
if self.use_dcp:
|
||||||
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
||||||
workspace_buffer=self._get_workspace_buffer(),
|
workspace_buffer=self._get_workspace_buffer(),
|
||||||
)
|
)
|
||||||
@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
if decode_wrapper is None:
|
if decode_wrapper is None:
|
||||||
if use_cudagraph:
|
if use_cudagraph:
|
||||||
paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
|
paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
|
||||||
paged_kv_indices = self.paged_kv_indices
|
paged_kv_indices = self.paged_kv_indices.gpu
|
||||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
|
paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
|
||||||
else:
|
else:
|
||||||
paged_kv_indptr = None
|
paged_kv_indptr = None
|
||||||
paged_kv_indices = None
|
paged_kv_indices = None
|
||||||
@ -661,6 +718,60 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
)
|
)
|
||||||
return self._cascade_wrapper
|
return self._cascade_wrapper
|
||||||
|
|
||||||
|
def _compute_flashinfer_kv_metadata(
|
||||||
|
self,
|
||||||
|
num_blocks_np: np.ndarray,
|
||||||
|
seq_lens_np: np.ndarray,
|
||||||
|
block_table_tensor: torch.Tensor,
|
||||||
|
num_reqs: int,
|
||||||
|
page_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
|
||||||
|
attention.
|
||||||
|
|
||||||
|
Results are stored in self.paged_kv_indptr,
|
||||||
|
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
|
||||||
|
|
||||||
|
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
|
||||||
|
"""
|
||||||
|
# write self.paged_kv_indptr.np in place (index 0 is always 0)
|
||||||
|
np.cumsum(
|
||||||
|
num_blocks_np,
|
||||||
|
dtype=np.int32,
|
||||||
|
out=self.paged_kv_indptr.np[1 : num_reqs + 1],
|
||||||
|
)
|
||||||
|
# NOTE(woosuk): Because self.paged_kv_indptr.cpu can be modified
|
||||||
|
# after this line (e.g., for cuda graphs), we need to copy the data to
|
||||||
|
# self.paged_kv_indptr_cpu_buffer to avoid a race condition.
|
||||||
|
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
|
||||||
|
: num_reqs + 1
|
||||||
|
]
|
||||||
|
paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
|
||||||
|
paged_kv_indptr.copy_(
|
||||||
|
self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# write self.paged_kv_indices inplace
|
||||||
|
num_actual_pages = self.paged_kv_indptr.np[num_reqs]
|
||||||
|
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
|
||||||
|
_copy_page_indices_kernel[(num_reqs,)](
|
||||||
|
paged_kv_indices,
|
||||||
|
block_table_tensor,
|
||||||
|
block_table_tensor.stride(0),
|
||||||
|
paged_kv_indptr,
|
||||||
|
BLOCK_SIZE=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
# write self.paged_kv_last_page_len.np in place
|
||||||
|
paged_kv_last_page_len_np = seq_lens_np % page_size
|
||||||
|
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
|
||||||
|
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
|
||||||
|
page_size,
|
||||||
|
paged_kv_last_page_len_np,
|
||||||
|
)
|
||||||
|
return paged_kv_indices
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
@ -678,98 +789,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
page_size = self.page_size
|
page_size = self.page_size
|
||||||
max_q_len = common_attn_metadata.max_query_len
|
|
||||||
max_seq_len = common_attn_metadata.max_seq_len
|
max_seq_len = common_attn_metadata.max_seq_len
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
|
||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
|
qo_indptr = common_attn_metadata.query_start_loc
|
||||||
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
|
|
||||||
if self.dcp_world_size > 1:
|
# Step 1: Decide which dispatch modes to use:
|
||||||
if num_prefills > 0:
|
# - Cascade attention (distinct mode)
|
||||||
qo_indptr_prefill_cpu = (
|
# - Prefill (FI native or TRTLLM)
|
||||||
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
|
# - Decode (FI native or TRTLLM)
|
||||||
)
|
|
||||||
query_lens_prefill_cpu = (
|
|
||||||
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
|
||||||
)
|
|
||||||
seq_lens_cpu[num_decodes:] = (
|
|
||||||
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
|
|
||||||
)
|
|
||||||
|
|
||||||
seq_lens_cpu = get_dcp_local_seq_lens(
|
|
||||||
seq_lens_cpu,
|
|
||||||
self.dcp_world_size,
|
|
||||||
self.dcp_rank,
|
|
||||||
self.dcp_kv_cache_interleave_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq_lens_np = seq_lens_cpu.numpy()
|
|
||||||
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
|
|
||||||
|
|
||||||
use_cascade = common_prefix_len > 0
|
use_cascade = common_prefix_len > 0
|
||||||
if use_cascade:
|
|
||||||
# Grab the blocks of the shared prefix from the first request.
|
|
||||||
assert common_prefix_len % page_size == 0
|
|
||||||
num_common_kv_blocks = common_prefix_len // page_size
|
|
||||||
|
|
||||||
# Create CPU versions directly for cascade (no GPU versions needed)
|
|
||||||
shared_qo_indptr_cpu = torch.tensor(
|
|
||||||
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
|
|
||||||
)
|
|
||||||
shared_kv_page_indptr_cpu = torch.tensor(
|
|
||||||
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
|
|
||||||
)
|
|
||||||
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
|
|
||||||
shared_kv_last_page_len_cpu = torch.tensor(
|
|
||||||
[page_size], dtype=torch.int32, device="cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove the blocks of the shared prefix from all requests.
|
|
||||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
|
||||||
num_blocks_np -= num_common_kv_blocks
|
|
||||||
else:
|
|
||||||
shared_qo_indptr_cpu = None
|
|
||||||
shared_kv_page_indptr_cpu = None
|
|
||||||
shared_kv_page_indices_cpu = None
|
|
||||||
shared_kv_last_page_len_cpu = None
|
|
||||||
|
|
||||||
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
|
|
||||||
np.cumsum(
|
|
||||||
num_blocks_np,
|
|
||||||
dtype=np.int32,
|
|
||||||
out=self.paged_kv_indptr_np[1 : num_reqs + 1],
|
|
||||||
)
|
|
||||||
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
|
|
||||||
# after this line (e.g., for cuda graphs), we need to copy the data to
|
|
||||||
# self.paged_kv_indptr_buffer to avoid race condition.
|
|
||||||
self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
|
|
||||||
: num_reqs + 1
|
|
||||||
]
|
|
||||||
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
|
|
||||||
paged_kv_indptr.copy_(
|
|
||||||
self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# write self.paged_kv_indices inplace
|
|
||||||
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
|
|
||||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
|
||||||
_copy_page_indices_kernel[(num_reqs,)](
|
|
||||||
paged_kv_indices,
|
|
||||||
block_table_tensor,
|
|
||||||
block_table_tensor.stride(0),
|
|
||||||
paged_kv_indptr,
|
|
||||||
BLOCK_SIZE=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
# write self.paged_kv_last_page_len_cpu inplace
|
|
||||||
paged_kv_last_page_len_np = seq_lens_np % page_size
|
|
||||||
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
|
|
||||||
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
|
|
||||||
page_size,
|
|
||||||
paged_kv_last_page_len_np,
|
|
||||||
)
|
|
||||||
|
|
||||||
uses_spec_reorder = self.reorder_batch_threshold > 1
|
uses_spec_reorder = self.reorder_batch_threshold > 1
|
||||||
prefill_use_trtllm = use_trtllm_attention(
|
prefill_use_trtllm = use_trtllm_attention(
|
||||||
self.num_qo_heads,
|
self.num_qo_heads,
|
||||||
@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
|
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (prefill_use_trtllm and decode_use_trtllm):
|
all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
|
||||||
|
num_decodes == 0 or decode_use_trtllm
|
||||||
|
)
|
||||||
|
is_only_trtllm_decode = num_prefills == 0 and (
|
||||||
|
num_decodes > 0 and decode_use_trtllm
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_uses_trtllm:
|
||||||
if self.has_sinks:
|
if self.has_sinks:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FlashInfer backend currently does not support attention "
|
"FlashInfer backend currently does not support attention "
|
||||||
@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
# fall back to model dtype.
|
# fall back to model dtype.
|
||||||
self.q_data_type = self.model_config.dtype
|
self.q_data_type = self.model_config.dtype
|
||||||
|
|
||||||
|
# Step 2: Initialize the output metadata
|
||||||
|
# Leave prefill/decode/cascade_wrapper empty, to be populated
|
||||||
|
# case by case depending on the batch contents and backend selection.
|
||||||
attn_metadata = FlashInferMetadata(
|
attn_metadata = FlashInferMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
max_q_len=max_q_len,
|
q_data_type=self.q_data_type,
|
||||||
max_q_len_prefill=max_q_len,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
block_table_tensor=block_table_tensor,
|
|
||||||
prefill_use_trtllm=prefill_use_trtllm,
|
|
||||||
decode_use_trtllm=decode_use_trtllm,
|
|
||||||
num_decodes=num_decodes,
|
num_decodes=num_decodes,
|
||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
|
prefill=None,
|
||||||
|
decode=None,
|
||||||
|
cascade_wrapper=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
|
# Guard access to seq_lens_cpu, which may not always be needed
|
||||||
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
|
# and can be expensive to retrieve in async mode.
|
||||||
|
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
|
||||||
|
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
|
||||||
|
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
|
||||||
|
num_blocks_np = (
|
||||||
|
(seq_lens_np + (page_size - 1)) // page_size
|
||||||
|
if seq_lens_np is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
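# Adjust seq_lens_cpu for DCP. NOTE(review): seq_lens_np and num_blocks_np were derived above, before this adjustment, and are not recomputed after seq_lens_cpu is rebound below -- confirm the KV metadata is meant to use the pre-DCP lengths.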
|
||||||
|
if self.use_dcp:
|
||||||
|
assert seq_lens_cpu is not None
|
||||||
|
if num_prefills > 0:
|
||||||
|
qo_indptr_prefill_cpu = (
|
||||||
|
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
|
||||||
|
)
|
||||||
|
query_lens_prefill_cpu = (
|
||||||
|
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
||||||
|
)
|
||||||
|
seq_lens_cpu[num_decodes:] = (
|
||||||
|
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_lens_cpu = get_dcp_local_seq_lens(
|
||||||
|
seq_lens_cpu,
|
||||||
|
self.dcp_world_size,
|
||||||
|
self.dcp_rank,
|
||||||
|
self.dcp_kv_cache_interleave_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adjust num_blocks_np for cascade attention
|
||||||
|
if use_cascade:
|
||||||
|
assert num_blocks_np is not None
|
||||||
|
assert common_prefix_len % page_size == 0
|
||||||
|
num_common_kv_blocks = common_prefix_len // page_size
|
||||||
|
num_blocks_np -= num_common_kv_blocks
|
||||||
|
|
||||||
|
# Compute paged_kv_indices if necessary
|
||||||
|
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
|
||||||
|
if needs_paged_kv_indices:
|
||||||
|
assert num_blocks_np is not None
|
||||||
|
assert seq_lens_np is not None
|
||||||
|
paged_kv_indices = self._compute_flashinfer_kv_metadata(
|
||||||
|
num_blocks_np,
|
||||||
|
seq_lens_np,
|
||||||
|
block_table_tensor,
|
||||||
|
num_reqs,
|
||||||
|
page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
paged_kv_indices = None
|
||||||
|
|
||||||
|
# Early-out for cascade attention
|
||||||
|
if use_cascade:
|
||||||
|
# Grab the blocks of the shared prefix from the first request.
|
||||||
|
num_common_kv_blocks = common_prefix_len // page_size
|
||||||
|
|
||||||
|
# Create CPU versions directly for cascade (no GPU versions needed)
|
||||||
|
shared_qo_indptr_cpu = torch.tensor(
|
||||||
|
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
|
||||||
|
)
|
||||||
|
shared_kv_page_indptr_cpu = torch.tensor(
|
||||||
|
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
|
||||||
|
)
|
||||||
|
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
|
||||||
|
shared_kv_last_page_len_cpu = torch.tensor(
|
||||||
|
[page_size], dtype=torch.int32, device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
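# Remove the blocks of the shared prefix from all requests. NOTE(review): _compute_flashinfer_kv_metadata has already been called with the unsliced block_table_tensor at this point -- confirm the page indices it produced are meant to include the shared-prefix blocks.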
|
||||||
|
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||||
|
num_blocks_np -= num_common_kv_blocks
|
||||||
|
|
||||||
|
assert paged_kv_indices is not None
|
||||||
|
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
|
||||||
|
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
|
||||||
|
|
||||||
if attn_metadata.use_cascade:
|
|
||||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||||
attn_metadata.cascade_wrapper.plan(
|
attn_metadata.cascade_wrapper.plan(
|
||||||
[shared_qo_indptr_cpu, qo_indptr_cpu],
|
[shared_qo_indptr_cpu, qo_indptr_cpu],
|
||||||
@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
kv_data_type=self.kv_cache_dtype,
|
kv_data_type=self.kv_cache_dtype,
|
||||||
)
|
)
|
||||||
else:
|
return attn_metadata
|
||||||
# Regular attention (common case).
|
|
||||||
# Decodes are at the front and prefills are at the back.
|
|
||||||
num_prefills = attn_metadata.num_prefills
|
|
||||||
num_decodes = attn_metadata.num_decodes
|
|
||||||
if num_prefills > 0:
|
|
||||||
# Decodes are first so prefills start after the last decode
|
|
||||||
prefill_start = num_decodes
|
|
||||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
|
||||||
assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
|
|
||||||
assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
|
|
||||||
assert (
|
|
||||||
paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
|
|
||||||
)
|
|
||||||
# Since prefill_wrapper.run() will be called with
|
|
||||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
|
||||||
# to be relative to the start of the prefill queries.
|
|
||||||
qo_indptr_cpu = (
|
|
||||||
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
|
|
||||||
)
|
|
||||||
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
|
|
||||||
|
|
||||||
# Recompute max_q_len for the slice of requests we are using
|
# Step 3: Handle prefill and decode pathways case by case
|
||||||
# for prefills. This can be different from max_q_len when
|
## PREFILL PATHWAY
|
||||||
# we have a non-uniform batch with some short decodes offloaded
|
if num_prefills > 0:
|
||||||
# to the prefill pathway
|
# Slices for shared prefill metadata
|
||||||
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
|
prefill_start = num_decodes
|
||||||
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
|
qo_indptr_prefill_cpu = (
|
||||||
|
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
|
||||||
|
)
|
||||||
|
assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1
|
||||||
|
|
||||||
if not attn_metadata.prefill_use_trtllm:
|
if prefill_use_trtllm:
|
||||||
if self.dcp_world_size > 1:
|
# Create GPU versions
|
||||||
assert isinstance(
|
qo_indptr_prefill_gpu = (
|
||||||
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
|
qo_indptr[prefill_start:] - qo_indptr[prefill_start]
|
||||||
)
|
)
|
||||||
attn_metadata.prefill_wrapper.plan(
|
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
|
||||||
qo_indptr_cpu=qo_indptr_cpu,
|
prefill_start : num_reqs + 1
|
||||||
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
|
]
|
||||||
paged_kv_indices=paged_kv_indices,
|
# Compute max_q_len for prefill requests
|
||||||
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
|
query_lens_prefill_cpu = (
|
||||||
prefill_start=prefill_start,
|
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
|
||||||
page_size=self.page_size,
|
)
|
||||||
num_qo_heads=self.num_qo_heads,
|
max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
|
||||||
dcp_world_size=self.dcp_world_size,
|
attn_metadata.prefill = TRTLLMPrefill(
|
||||||
num_kv_heads=self.num_kv_heads,
|
block_tables=block_table_tensor[prefill_start:],
|
||||||
head_dim=self.head_dim,
|
seq_lens=seq_lens[prefill_start:],
|
||||||
sm_scale=self.sm_scale,
|
cum_seq_lens_q=qo_indptr_prefill_gpu,
|
||||||
window_left=self.window_left,
|
cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
|
||||||
logits_soft_cap=self.logits_soft_cap,
|
max_q_len=max_q_len_prefill,
|
||||||
q_data_type=self.q_data_type,
|
max_seq_len=max_seq_len,
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
)
|
||||||
prefill_fixed_split_size=self.prefill_fixed_split_size,
|
else:
|
||||||
disable_split_kv=self.disable_split_kv,
|
prefill_wrapper = self._get_prefill_wrapper()
|
||||||
)
|
# Slicing CPU buffers that are only needed for FI native prefills
|
||||||
else:
|
paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
|
||||||
assert isinstance(
|
prefill_start:num_reqs
|
||||||
attn_metadata.prefill_wrapper,
|
]
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
|
||||||
)
|
paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
|
||||||
attn_metadata.prefill_wrapper.plan(
|
prefill_start : num_reqs + 1
|
||||||
qo_indptr_cpu,
|
]
|
||||||
paged_kv_indptr_cpu,
|
assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
|
||||||
paged_kv_indices,
|
if self.use_dcp:
|
||||||
paged_kv_last_page_len_cpu[prefill_start:],
|
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
|
||||||
self.num_qo_heads,
|
prefill_wrapper.plan(
|
||||||
self.num_kv_heads,
|
qo_indptr_cpu=qo_indptr_prefill_cpu,
|
||||||
self.head_dim,
|
paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
|
||||||
self.page_size,
|
paged_kv_indices=paged_kv_indices,
|
||||||
causal=True,
|
paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
|
||||||
sm_scale=self.sm_scale,
|
page_size=self.page_size,
|
||||||
window_left=self.window_left,
|
num_qo_heads=self.num_qo_heads,
|
||||||
logits_soft_cap=self.logits_soft_cap,
|
dcp_world_size=self.dcp_world_size,
|
||||||
q_data_type=self.q_data_type,
|
num_kv_heads=self.num_kv_heads,
|
||||||
kv_data_type=self.kv_cache_dtype,
|
head_dim=self.head_dim,
|
||||||
fixed_split_size=self.prefill_fixed_split_size,
|
sm_scale=self.sm_scale,
|
||||||
disable_split_kv=self.disable_split_kv,
|
window_left=self.window_left,
|
||||||
)
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
|
prefill_fixed_split_size=self.prefill_fixed_split_size,
|
||||||
|
disable_split_kv=self.disable_split_kv,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
|
assert isinstance(
|
||||||
self.device, non_blocking=True
|
prefill_wrapper,
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
|
prefill_wrapper.plan(
|
||||||
self.device, non_blocking=True
|
qo_indptr_prefill_cpu,
|
||||||
|
paged_kv_indptr_prefill_cpu,
|
||||||
|
paged_kv_indices,
|
||||||
|
paged_kv_last_page_len_prefill_cpu,
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.page_size,
|
||||||
|
causal=True,
|
||||||
|
sm_scale=self.sm_scale,
|
||||||
|
window_left=self.window_left,
|
||||||
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
kv_data_type=self.kv_cache_dtype,
|
||||||
|
fixed_split_size=self.prefill_fixed_split_size,
|
||||||
|
disable_split_kv=self.disable_split_kv,
|
||||||
)
|
)
|
||||||
|
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
|
||||||
|
|
||||||
if num_decodes > 0:
|
## DECODE PATHWAY
|
||||||
|
if num_decodes > 0:
|
||||||
|
if decode_use_trtllm:
|
||||||
|
assert num_decode_tokens % num_decodes == 0, (
|
||||||
|
"TRTLLM decode requires uniform query lengths per request."
|
||||||
|
)
|
||||||
|
attn_metadata.decode = TRTLLMDecode(
|
||||||
|
block_tables=block_table_tensor[:num_decodes],
|
||||||
|
seq_lens=seq_lens[:num_decodes],
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
pure_decode = num_prefills == 0
|
pure_decode = num_prefills == 0
|
||||||
use_cudagraph = (
|
use_cudagraph = (
|
||||||
self.enable_cuda_graph
|
self.enable_cuda_graph
|
||||||
@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
)
|
)
|
||||||
num_input_tokens = num_decode_tokens
|
num_input_tokens = num_decode_tokens
|
||||||
|
|
||||||
attn_metadata.decode_wrapper = self._get_decode_wrapper(
|
decode_wrapper = self._get_decode_wrapper(
|
||||||
num_input_tokens, use_cudagraph
|
num_input_tokens, use_cudagraph
|
||||||
)
|
)
|
||||||
if not attn_metadata.decode_use_trtllm:
|
# Use the persistent buffer with padding length,
|
||||||
# Use the persistent buffer with padding length,
|
# instead of the same address but chunked version
|
||||||
# instead of the same address but chunked version
|
# in atten_metadata when using cudagraph.
|
||||||
# in atten_metadata when using cudagraph.
|
fast_plan_decode(
|
||||||
fast_plan_decode(
|
decode_wrapper,
|
||||||
attn_metadata.decode_wrapper,
|
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
|
||||||
self.paged_kv_indptr_cpu[: num_input_tokens + 1],
|
paged_kv_indices,
|
||||||
paged_kv_indices,
|
self.paged_kv_last_page_len.cpu[:num_input_tokens],
|
||||||
self.paged_kv_last_page_len_cpu[:num_input_tokens],
|
seq_lens_cpu[:num_input_tokens],
|
||||||
seq_lens_cpu[:num_input_tokens],
|
self.num_qo_heads * self.dcp_world_size,
|
||||||
self.num_qo_heads * self.dcp_world_size,
|
self.num_kv_heads,
|
||||||
self.num_kv_heads,
|
self.head_dim,
|
||||||
self.head_dim,
|
self.page_size,
|
||||||
self.page_size,
|
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
pos_encoding_mode="NONE",
|
||||||
pos_encoding_mode="NONE",
|
sm_scale=self.sm_scale,
|
||||||
sm_scale=self.sm_scale,
|
window_left=self.window_left,
|
||||||
window_left=self.window_left,
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
logits_soft_cap=self.logits_soft_cap,
|
q_data_type=self.q_data_type,
|
||||||
q_data_type=self.q_data_type,
|
kv_data_type=self.kv_cache_dtype,
|
||||||
kv_data_type=self.kv_cache_dtype,
|
fixed_split_size=self.decode_fixed_split_size,
|
||||||
fixed_split_size=self.decode_fixed_split_size,
|
disable_split_kv=self.disable_split_kv,
|
||||||
disable_split_kv=self.disable_split_kv,
|
)
|
||||||
)
|
attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
if self.bmm2_scale is None:
|
if self.bmm2_scale is None:
|
||||||
self.bmm2_scale = layer._v_scale_float
|
self.bmm2_scale = layer._v_scale_float
|
||||||
|
|
||||||
|
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||||
|
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||||
|
|
||||||
# The attn+quant fusion happens when output_scale is provided.
|
# The attn+quant fusion happens when output_scale is provided.
|
||||||
if output_scale is None:
|
if output_scale is None:
|
||||||
assert output_block_scale is None, (
|
assert output_block_scale is None, (
|
||||||
@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
assert attn_metadata.q_data_type == FP8_DTYPE, (
|
assert attn_metadata.q_data_type == FP8_DTYPE, (
|
||||||
"Query must be FP8 when attn+quant fusion happened."
|
"Query must be FP8 when attn+quant fusion happened."
|
||||||
)
|
)
|
||||||
assert (
|
assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
|
||||||
attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
|
attn_metadata.num_decodes == 0 or decode_use_trtllm
|
||||||
), "Must use TRT-LLM attn"
|
), "Must use TRT-LLM attn"
|
||||||
|
|
||||||
if output.dtype == FP8_DTYPE:
|
if output.dtype == FP8_DTYPE:
|
||||||
@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
|
|
||||||
# When using spec decoding, num_decodes can be < num_decode_tokens
|
# When using spec decoding, num_decodes can be < num_decode_tokens
|
||||||
# because some decode requests may have more than one query token.
|
# because some decode requests may have more than one query token.
|
||||||
num_decodes = attn_metadata.num_decodes
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
|
|
||||||
stride_order = FlashInferBackend.get_kv_cache_stride_order()
|
stride_order = FlashInferBackend.get_kv_cache_stride_order()
|
||||||
kv_cache_permute = kv_cache.permute(*stride_order)
|
kv_cache_permute = kv_cache.permute(*stride_order)
|
||||||
|
|
||||||
|
use_dcp = self.dcp_world_size > 1
|
||||||
|
|
||||||
# Regular attention (common case).
|
# Regular attention (common case).
|
||||||
# Decodes are at the front and prefills are at the back.
|
# Decodes are at the front and prefills are at the back.
|
||||||
if num_prefill_tokens > 0:
|
if num_prefill_tokens > 0:
|
||||||
prefill_wrapper = attn_metadata.prefill_wrapper
|
|
||||||
prefill_query = query[num_decode_tokens:]
|
prefill_query = query[num_decode_tokens:]
|
||||||
assert prefill_query.shape[0] == num_prefill_tokens
|
assert prefill_query.shape[0] == num_prefill_tokens
|
||||||
assert prefill_wrapper is not None
|
|
||||||
|
|
||||||
if not attn_metadata.prefill_use_trtllm:
|
if not prefill_use_trtllm:
|
||||||
if self.dcp_world_size > 1:
|
assert isinstance(attn_metadata.prefill, FIPrefill)
|
||||||
|
prefill_wrapper = attn_metadata.prefill.wrapper
|
||||||
|
assert prefill_wrapper is not None
|
||||||
|
if use_dcp:
|
||||||
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
|
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
|
||||||
assert prefill_wrapper._context._window_left == self.window_left
|
assert prefill_wrapper._context._window_left == self.window_left
|
||||||
assert prefill_wrapper._context._logits_soft_cap == (
|
assert prefill_wrapper._context._logits_soft_cap == (
|
||||||
@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
out=output[num_decode_tokens:],
|
out=output[num_decode_tokens:],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||||
# prefill_query may be non-contiguous
|
# prefill_query may be non-contiguous
|
||||||
prefill_query = prefill_query.contiguous()
|
prefill_query = prefill_query.contiguous()
|
||||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
|
block_tables_prefill = attn_metadata.prefill.block_tables
|
||||||
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
|
seq_lens_prefill = attn_metadata.prefill.seq_lens
|
||||||
|
|
||||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||||
assert get_kv_cache_layout() == "HND"
|
assert get_kv_cache_layout() == "HND"
|
||||||
@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
workspace_buffer=workspace_buffer,
|
workspace_buffer=workspace_buffer,
|
||||||
block_tables=mock_block_table,
|
block_tables=mock_block_table,
|
||||||
seq_lens=seq_lens_prefill,
|
seq_lens=seq_lens_prefill,
|
||||||
max_q_len=attn_metadata.max_q_len_prefill,
|
max_q_len=attn_metadata.prefill.max_q_len,
|
||||||
max_kv_len=attn_metadata.max_seq_len,
|
max_kv_len=attn_metadata.prefill.max_seq_len,
|
||||||
bmm1_scale=self.bmm1_scale,
|
bmm1_scale=self.bmm1_scale,
|
||||||
bmm2_scale=self.bmm2_scale,
|
bmm2_scale=self.bmm2_scale,
|
||||||
batch_size=attn_metadata.num_prefills,
|
batch_size=attn_metadata.num_prefills,
|
||||||
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
|
cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
|
||||||
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
|
cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
|
||||||
window_left=self.window_left,
|
window_left=self.window_left,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
o_sf_scale=self.o_sf_scale,
|
o_sf_scale=self.o_sf_scale,
|
||||||
@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if num_decode_tokens > 0:
|
if num_decode_tokens > 0:
|
||||||
decode_wrapper = attn_metadata.decode_wrapper
|
|
||||||
decode_query = query[:num_decode_tokens]
|
decode_query = query[:num_decode_tokens]
|
||||||
assert decode_query.shape[0] == num_decode_tokens
|
assert decode_query.shape[0] == num_decode_tokens
|
||||||
assert decode_wrapper is not None
|
|
||||||
|
|
||||||
if not attn_metadata.decode_use_trtllm:
|
if not decode_use_trtllm:
|
||||||
|
assert isinstance(attn_metadata.decode, FIDecode)
|
||||||
|
decode_wrapper = attn_metadata.decode.wrapper
|
||||||
|
assert decode_wrapper is not None
|
||||||
assert decode_wrapper._window_left == self.window_left
|
assert decode_wrapper._window_left == self.window_left
|
||||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
|
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
|
||||||
assert decode_wrapper._sm_scale == self.scale
|
assert decode_wrapper._sm_scale == self.scale
|
||||||
|
|
||||||
if self.dcp_world_size > 1:
|
if use_dcp:
|
||||||
decode_query = get_dcp_group().all_gather(
|
decode_query = get_dcp_group().all_gather(
|
||||||
decode_query.contiguous(), dim=-2
|
decode_query.contiguous(), dim=-2
|
||||||
)
|
)
|
||||||
@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# decode_query may be non-contiguous
|
# decode_query may be non-contiguous
|
||||||
|
assert isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||||
decode_query = decode_query.contiguous()
|
decode_query = decode_query.contiguous()
|
||||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_decode = attn_metadata.block_table_tensor[
|
block_tables_decode = attn_metadata.decode.block_tables
|
||||||
:num_decode_tokens
|
seq_lens_decode = attn_metadata.decode.seq_lens
|
||||||
]
|
|
||||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
|
||||||
|
|
||||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||||
assert get_kv_cache_layout() == "HND"
|
assert get_kv_cache_layout() == "HND"
|
||||||
@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
workspace_buffer=workspace_buffer,
|
workspace_buffer=workspace_buffer,
|
||||||
block_tables=block_tables_decode,
|
block_tables=block_tables_decode,
|
||||||
seq_lens=seq_lens_decode,
|
seq_lens=seq_lens_decode,
|
||||||
max_seq_len=attn_metadata.max_seq_len,
|
max_seq_len=attn_metadata.decode.max_seq_len,
|
||||||
bmm1_scale=self.bmm1_scale,
|
bmm1_scale=self.bmm1_scale,
|
||||||
bmm2_scale=self.bmm2_scale,
|
bmm2_scale=self.bmm2_scale,
|
||||||
window_left=self.window_left,
|
window_left=self.window_left,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user