[Misc] Minor refactoring for FlashInfer backend (#23147)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit e61bac87ee (parent 80141bbf2f)
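The diff below follows one pattern throughout: values that are fixed for the lifetime of the model (head counts, head size, page size, dtypes, the tensor-core heuristic) move out of the per-step FlashInferMetadata dataclass and into FlashInferMetadataBuilder.__init__, where they are computed once and read back as self.* attributes. A minimal sketch of that pattern, using simplified stand-in classes rather than vLLM's real ones:

from dataclasses import dataclass

import torch


@dataclass
class StepMetadata:
    # After the refactor, only values that change on every scheduling step
    # remain in the per-step metadata object.
    seq_lens_cpu: torch.Tensor
    slot_mapping: torch.Tensor
    q_data_type: torch.dtype


class MetadataBuilder:

    def __init__(self, num_qo_heads: int, num_kv_heads: int, head_dim: int,
                 page_size: int, q_data_type: torch.dtype):
        # Model-level constants are computed once here ...
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.page_size = page_size
        self.q_data_type = q_data_type
        # ... including decisions derived from them, such as the
        # tensor-core heuristic used for the decode wrapper.
        self.use_tensor_cores = num_qo_heads // num_kv_heads > 4

    def build(self, seq_lens_cpu: torch.Tensor,
              slot_mapping: torch.Tensor) -> StepMetadata:
        # build() now only packages per-step tensors; plan() calls and
        # trtllm checks read the constants from self.* instead.
        return StepMetadata(seq_lens_cpu=seq_lens_cpu,
                            slot_mapping=slot_mapping,
                            q_data_type=self.q_data_type)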
@@ -10,8 +10,7 @@ import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-from flashinfer.decode import (_get_range_buf, get_seq_lens,
-                               trtllm_batch_decode_with_kv_cache)
+from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache

 import vllm.envs as envs
@@ -142,19 +141,10 @@ class FlashInferMetadata:
     # The number of entries in the last page of each request in
     # the paged kv cache, shape: [batch_size] (CPU for plan)
     paged_kv_last_page_len_cpu: torch.Tensor
-    # The number of query/output heads
-    num_qo_heads: int
-    # The number of key/value heads
-    num_kv_heads: int
-    # The dimension of the attention heads
-    head_dim: int
-    # Block size of vllm
-    page_size: int
-    # The data type of the paged kv cache
-    kv_data_type: torch.dtype
     # The data type of the query
     q_data_type: torch.dtype
+
     seq_lens_cpu: torch.Tensor
     slot_mapping: torch.Tensor

     # For flashinfer trtllm batch decode
@@ -185,10 +175,6 @@ class FlashInferMetadata:
     qo_indptr_gpu: Optional[torch.Tensor] = None
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None

-    def __post_init__(self):
-        if self.head_dim is not None:
-            FlashInferBackend.validate_head_size(self.head_dim)
-

 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
@@ -201,13 +187,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.device = device
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
+        self.model_config = vllm_config.model_config
         self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

         self.compilation_config = vllm_config.compilation_config
-        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+        max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
@@ -221,6 +208,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._decode_cudagraph_max_bs = min(
             max_num_reqs, self.compilation_config.max_capture_size)

+        self.num_qo_heads = self.model_config.get_num_attention_heads(
+            self.vllm_config.parallel_config)
+        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
+        self.head_dim = self.kv_cache_spec.head_size
+        FlashInferBackend.validate_head_size(self.head_dim)
+        self.page_size = self.kv_cache_spec.block_size
+
+        self.enable_fusion = (
+            self.compilation_config.pass_config.enable_attn_fusion)
+        self.q_data_type = self.model_config.dtype
+        self.cache_dtype = self.cache_config.cache_dtype
+        if self.cache_dtype.startswith("fp8"):
+            self.kv_cache_dtype = (
+                FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                    self.cache_dtype))
+            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
+            if self.enable_fusion:
+                self.q_data_type = self.kv_cache_dtype
+        else:
+            self.kv_cache_dtype = self.kv_cache_spec.dtype
+        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
+                                 (self.num_qo_heads // self.num_kv_heads > 4))
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
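The new __init__ code above fixes the query/KV-cache dtypes and the tensor-core decision once per model instead of re-deriving them in build(). A rough standalone restatement of that selection logic is sketched below; the helper name and the fp8_view_dtype parameter are illustrative stand-ins (the real code calls FlashInferBackend.get_fp8_dtype_for_flashinfer() for that mapping):

import torch


def pick_dtypes_and_tensor_cores(model_dtype: torch.dtype,
                                 cache_dtype: str,
                                 fp8_view_dtype: torch.dtype,
                                 kv_spec_dtype: torch.dtype,
                                 enable_attn_fusion: bool,
                                 num_qo_heads: int,
                                 num_kv_heads: int,
                                 force_tensor_cores: bool = False):
    # Query dtype defaults to the model dtype.
    q_data_type = model_dtype
    if cache_dtype.startswith("fp8"):
        # FP8 KV cache: store KV in the FP8 dtype chosen for FlashInfer.
        kv_cache_dtype = fp8_view_dtype
        # With attention fusion enabled, the query is quantized to the same
        # FP8 dtype so the fused kernel consumes FP8 inputs end to end.
        if enable_attn_fusion:
            q_data_type = kv_cache_dtype
    else:
        kv_cache_dtype = kv_spec_dtype
    # Tensor-core decode path: forced via env, or GQA group size > 4.
    use_tensor_cores = force_tensor_cores or (num_qo_heads // num_kv_heads > 4)
    return q_data_type, kv_cache_dtype, use_tensor_cores


# Example: bf16 model, fp8 KV cache, fusion on, 32 query heads over 8 KV heads.
print(pick_dtypes_and_tensor_cores(torch.bfloat16, "fp8", torch.float8_e4m3fn,
                                   torch.bfloat16, True, 32, 8))
# -> (torch.float8_e4m3fn, torch.float8_e4m3fn, False)  (32 // 8 == 4, not > 4)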
@@ -282,14 +292,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         decode_wrapper = self._decode_wrapper

         if decode_wrapper is None:
-            num_qo_heads = (
-                self.vllm_config.model_config.get_num_attention_heads(
-                    self.vllm_config.parallel_config))
-            num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
-                self.vllm_config.parallel_config)
-            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
-                num_qo_heads // num_kv_heads > 4)
-
             if use_cudagraph:
                 paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
                 paged_kv_indices = self.paged_kv_indices
@@ -306,7 +308,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     paged_kv_indptr_buffer=paged_kv_indptr,
                     paged_kv_indices_buffer=paged_kv_indices,
                     paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                    use_tensor_cores=use_tensor_cores)
+                    use_tensor_cores=self.use_tensor_cores)

             # save the decode wrapper
             if use_cudagraph:
@@ -342,16 +344,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.shared_kv_last_page_len_cpu,
                     attn_metadata.paged_kv_last_page_len_cpu
                 ],
-                attn_metadata.num_qo_heads,
-                attn_metadata.num_kv_heads,
-                attn_metadata.head_dim,
-                attn_metadata.page_size,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
                 causal=True,
                 sm_scale=self.global_hyperparameters.sm_scale,
                 window_left=self.global_hyperparameters.window_left,
                 logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
-                q_data_type=attn_metadata.q_data_type,
-                kv_data_type=attn_metadata.kv_data_type,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.kv_cache_dtype,
             )
         else:
             # Regular attention (common case).
@@ -383,17 +385,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.paged_kv_indices,
                     attn_metadata.
                     paged_kv_last_page_len_cpu[prefill_start:],
-                    attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim,
-                    attn_metadata.page_size,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    self.page_size,
                     causal=True,
                     sm_scale=self.global_hyperparameters.sm_scale,
                     window_left=self.global_hyperparameters.window_left,
                     logits_soft_cap=self.global_hyperparameters.
                     logits_soft_cap,
-                    q_data_type=attn_metadata.q_data_type,
-                    kv_data_type=attn_metadata.kv_data_type,
+                    q_data_type=self.q_data_type,
+                    kv_data_type=self.kv_cache_dtype,
                 )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
@@ -435,18 +437,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                     attn_metadata.paged_kv_indices,
                     self.paged_kv_last_page_len_cpu[:num_input_tokens],
-                    attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads,
-                    attn_metadata.head_dim,
-                    attn_metadata.page_size,
+                    attn_metadata.seq_lens_cpu[:num_input_tokens],
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    self.page_size,
                     # Disable flashinfer's pos encoding and use vllm's rope.
                     pos_encoding_mode="NONE",
                     sm_scale=self.global_hyperparameters.sm_scale,
                     window_left=self.global_hyperparameters.window_left,
                     logits_soft_cap=self.global_hyperparameters.
                     logits_soft_cap,
-                    q_data_type=attn_metadata.q_data_type,
-                    kv_data_type=attn_metadata.kv_data_type,
+                    q_data_type=self.q_data_type,
+                    kv_data_type=self.kv_cache_dtype,
                 )

     def build(self,
@@ -458,9 +461,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)

-        page_size = self.kv_cache_spec.block_size
+        page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.seq_lens_cpu.max()
+        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
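The .item() additions in build() are a typing cleanup: max() on a tensor returns a 0-d torch.Tensor, while .item() yields a plain Python int, which the downstream integer consumers (threshold comparisons, the metadata fields) handle more predictably. Since seq_lens_cpu and block_table_bounds_cpu already live on the CPU, this is not a synchronization change. A tiny illustration:

import torch

seq_lens_cpu = torch.tensor([17, 5, 123], dtype=torch.int32)

as_tensor = seq_lens_cpu.max()         # tensor(123, dtype=torch.int32), a 0-d tensor
as_int = seq_lens_cpu.max().item()     # 123, a plain Python int

print(type(as_tensor).__name__, type(as_int).__name__)  # Tensor int
print(as_tensor > 100, as_int > 100)                    # tensor(True) True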
@@ -495,7 +498,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             shared_kv_page_indices_cpu = None
             shared_kv_last_page_len_cpu = None

-        max_num_blocks = block_table_bounds_cpu.max()
+        max_num_blocks = block_table_bounds_cpu.max().item()
         block_table_bounds = block_table_bounds_cpu.to(self.device,
                                                        non_blocking=True)
         mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
@@ -520,42 +523,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     paged_kv_last_page_len_cpu,
                     out=self.paged_kv_last_page_len_cpu[:num_reqs])

-        cache_dtype = self.cache_config.cache_dtype
-        if cache_dtype.startswith("fp8"):
-            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                cache_dtype)
-        else:
-            kv_cache_dtype = self.kv_cache_spec.dtype
-
-        config = self.vllm_config
-        num_qo_heads = config.model_config.get_num_attention_heads(
-            config.parallel_config)
-        num_kv_heads = self.kv_cache_spec.num_kv_heads
-        head_dim = self.kv_cache_spec.head_size
-
         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks

-        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
-        q_dtype = config.model_config.dtype
-        enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
-        if cache_dtype.startswith("fp8") and enable_fusion:
-            q_dtype = kv_cache_dtype
-
-        prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
-                                                  num_kv_heads,
+        prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
+                                                  self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
-                                                  cache_dtype,
-                                                  q_dtype,
+                                                  self.cache_dtype,
+                                                  self.q_data_type,
                                                   is_prefill=True,
                                                   has_sinks=has_sinks)
-        decode_use_trtllm = use_trtllm_attention(num_qo_heads,
-                                                 num_kv_heads,
+        decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
+                                                 self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
-                                                 cache_dtype,
-                                                 q_dtype,
+                                                 self.cache_dtype,
+                                                 self.q_data_type,
                                                  is_prefill=False,
                                                  has_sinks=has_sinks)

@@ -566,12 +550,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=self.
             paged_kv_last_page_len_cpu[:num_reqs],
-            num_qo_heads=num_qo_heads,
-            num_kv_heads=num_kv_heads,
-            head_dim=head_dim,
-            page_size=page_size,
-            kv_data_type=kv_cache_dtype,
-            q_data_type=q_dtype,
+            q_data_type=self.q_data_type,
             seq_lens_cpu=seq_lens_cpu,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -910,6 +890,7 @@ def fast_plan_decode(
     indptr_cpu: torch.Tensor,
     indices: torch.Tensor,
     last_page_len_cpu: torch.Tensor,
+    seq_lens_cpu: torch.Tensor,
     num_qo_heads: int,
     num_kv_heads: int,
     head_dim: int,
@@ -987,9 +968,6 @@ def fast_plan_decode(
     kv_data_type = getattr(torch, kv_data_type) if isinstance(
         kv_data_type, str) else kv_data_type

-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1006,12 +984,8 @@ def fast_plan_decode(
     self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                            non_blocking=True)

-    indptr_host = indptr_cpu
-    last_page_len_host = last_page_len_cpu
-
     if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
-                                        page_size)
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

         try:
             # Make sure we pass exactly 15 arguments for tensor core version
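Context for the get_seq_lens removal above: the old path reconstructed per-request KV lengths on the host from the paged-KV layout (page indptr plus the fill level of each request's last page) and fed them to plan(); the refactored fast_plan_decode takes seq_lens_cpu directly from the caller, so neither the recomputation nor the flashinfer.decode.get_seq_lens import is needed. A sketch of the arithmetic that reconstruction performs, as an illustrative reimplementation rather than flashinfer's own helper:

import torch


def seq_lens_from_paged_layout(indptr_cpu: torch.Tensor,
                               last_page_len_cpu: torch.Tensor,
                               page_size: int) -> torch.Tensor:
    # Request i owns pages indptr[i]:indptr[i+1]; every page is full except
    # the last one, which holds last_page_len[i] entries.
    num_pages = indptr_cpu[1:] - indptr_cpu[:-1]
    return torch.clamp(num_pages - 1, min=0) * page_size + last_page_len_cpu


# Example: two requests with 3 and 1 pages of size 16, last pages holding 5 and 9.
indptr = torch.tensor([0, 3, 4], dtype=torch.int32)
last_page_len = torch.tensor([5, 9], dtype=torch.int32)
print(seq_lens_from_paged_layout(indptr, last_page_len, 16))  # tensor([37, 9], dtype=torch.int32)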
@@ -1020,8 +994,8 @@ def fast_plan_decode(
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
                 qo_indptr_host,
-                indptr_host,
-                kv_lens_arr_host,
+                indptr_cpu,
+                seq_lens_cpu,
                 batch_size,  # total_num_rows
                 batch_size,
                 num_qo_heads,
@@ -1041,7 +1015,7 @@ def fast_plan_decode(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            indptr_host,
+            indptr_cpu,
             batch_size,
             num_qo_heads,
             num_kv_heads,