[Misc] Minor refactoring for FlashInfer backend (#23147)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-08-19 13:11:51 -07:00 committed by GitHub
parent 80141bbf2f
commit e61bac87ee


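The diff below moves values that stay fixed for the lifetime of the metadata builder (head counts, head size, page size, query/KV dtypes, and the tensor-core heuristic) out of FlashInferMetadata and the per-build path, caching them once on FlashInferMetadataBuilder. A minimal sketch of that pattern, using simplified hypothetical names rather than the real vLLM classes:

from dataclasses import dataclass

import torch


@dataclass
class Metadata:
    # Per-iteration state only; static shape/dtype info lives on the builder.
    q_data_type: torch.dtype
    seq_lens_cpu: torch.Tensor


class MetadataBuilder:

    def __init__(self, num_qo_heads: int, num_kv_heads: int, head_dim: int,
                 page_size: int, q_dtype: torch.dtype):
        # Computed once at construction time instead of on every build() call.
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.page_size = page_size
        self.q_data_type = q_dtype
        # GQA-ratio heuristic mirroring the use_tensor_cores check in the diff.
        self.use_tensor_cores = num_qo_heads // num_kv_heads > 4

    def build(self, seq_lens_cpu: torch.Tensor) -> Metadata:
        # Only per-batch tensors travel through the metadata object.
        return Metadata(q_data_type=self.q_data_type,
                        seq_lens_cpu=seq_lens_cpu)

The same idea shows up in fast_plan_decode, which now takes seq_lens_cpu directly instead of recomputing it with get_seq_lens.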
@@ -10,8 +10,7 @@ import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import (_get_range_buf, get_seq_lens,
trtllm_batch_decode_with_kv_cache)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs
@@ -142,19 +141,10 @@ class FlashInferMetadata:
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size] (CPU for plan)
paged_kv_last_page_len_cpu: torch.Tensor
# The number of query/output heads
num_qo_heads: int
# The number of key/value heads
num_kv_heads: int
# The dimension of the attention heads
head_dim: int
# Block size of vllm
page_size: int
# The data type of the paged kv cache
kv_data_type: torch.dtype
# The data type of the query
q_data_type: torch.dtype
seq_lens_cpu: torch.Tensor
slot_mapping: torch.Tensor
# For flashinfer trtllm batch decode
@@ -185,10 +175,6 @@ class FlashInferMetadata:
qo_indptr_gpu: Optional[torch.Tensor] = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None
def __post_init__(self):
if self.head_dim is not None:
FlashInferBackend.validate_head_size(self.head_dim)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
@@ -201,13 +187,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)
self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
max_num_pages_per_req = cdiv(self.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
@@ -221,6 +208,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config)
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size
self.enable_fusion = (
self.compilation_config.pass_config.enable_attn_fusion)
self.q_data_type = self.model_config.dtype
self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"):
self.kv_cache_dtype = (
FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype))
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
if self.enable_fusion:
self.q_data_type = self.kv_cache_dtype
else:
self.kv_cache_dtype = self.kv_cache_spec.dtype
self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
(self.num_qo_heads // self.num_kv_heads > 4))
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
@@ -282,14 +292,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
decode_wrapper = self._decode_wrapper
if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
@@ -306,7 +308,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
use_tensor_cores=self.use_tensor_cores)
# save the decode wrapper
if use_cudagraph:
@@ -342,16 +344,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.shared_kv_last_page_len_cpu,
attn_metadata.paged_kv_last_page_len_cpu
],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.kv_data_type,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
# Regular attention (common case).
@@ -383,17 +385,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.paged_kv_indices,
attn_metadata.
paged_kv_last_page_len_cpu[prefill_start:],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.kv_data_type,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
@@ -435,18 +437,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
attn_metadata.paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
attn_metadata.seq_lens_cpu[:num_input_tokens],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.kv_data_type,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
)
def build(self,
@@ -458,9 +461,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
page_size = self.kv_cache_spec.block_size
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
@@ -495,7 +498,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None
max_num_blocks = block_table_bounds_cpu.max()
max_num_blocks = block_table_bounds_cpu.max().item()
block_table_bounds = block_table_bounds_cpu.to(self.device,
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
@@ -520,42 +523,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len_cpu,
out=self.paged_kv_last_page_len_cpu[:num_reqs])
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
cache_dtype)
else:
kv_cache_dtype = self.kv_cache_spec.dtype
config = self.vllm_config
num_qo_heads = config.model_config.get_num_attention_heads(
config.parallel_config)
num_kv_heads = self.kv_cache_spec.num_kv_heads
head_dim = self.kv_cache_spec.head_size
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
q_dtype = config.model_config.dtype
enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
if cache_dtype.startswith("fp8") and enable_fusion:
q_dtype = kv_cache_dtype
prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
num_kv_heads,
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
cache_dtype,
q_dtype,
self.cache_dtype,
self.q_data_type,
is_prefill=True,
has_sinks=has_sinks)
decode_use_trtllm = use_trtllm_attention(num_qo_heads,
num_kv_heads,
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
cache_dtype,
q_dtype,
self.cache_dtype,
self.q_data_type,
is_prefill=False,
has_sinks=has_sinks)
@@ -566,12 +550,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=self.
paged_kv_last_page_len_cpu[:num_reqs],
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=q_dtype,
q_data_type=self.q_data_type,
seq_lens_cpu=seq_lens_cpu,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_seq_len=max_seq_len,
@@ -910,6 +890,7 @@ def fast_plan_decode(
indptr_cpu: torch.Tensor,
indices: torch.Tensor,
last_page_len_cpu: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
@@ -987,9 +968,6 @@ def fast_plan_decode(
kv_data_type = getattr(torch, kv_data_type) if isinstance(
kv_data_type, str) else kv_data_type
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime "
@@ -1006,12 +984,8 @@ def fast_plan_decode(
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
non_blocking=True)
indptr_host = indptr_cpu
last_page_len_host = last_page_len_cpu
if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
page_size)
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
# Make sure we pass exactly 15 arguments for tensor core version
@@ -1020,8 +994,8 @@ def fast_plan_decode(
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
indptr_cpu,
seq_lens_cpu,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
@@ -1041,7 +1015,7 @@ def fast_plan_decode(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
indptr_cpu,
batch_size,
num_qo_heads,
num_kv_heads,