mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:45:49 +08:00
[Attention] Optimize FlashInfer MetadataBuilder Build call (#21137)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
526078a96c
commit
61b8cea3b4
@ -11,7 +11,8 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
set_kv_cache_layout)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
@ -212,7 +213,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
|
||||
from vllm.v1.attention.backends.flashinfer import PerLayerParameters
|
||||
|
||||
def mock_get_per_layer_parameters(vllm_config):
|
||||
def mock_get_per_layer_parameters(vllm_config, impl_cls):
|
||||
# Return mock parameters for a single layer
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
return {
|
||||
@ -297,7 +298,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
vllm_config = create_vllm_config(model_name=model)
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens))
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@ -419,6 +421,11 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
if backend_name == _Backend.FLASHINFER_VLLM_V1:
|
||||
kv_cache_for_backend = kv_cache.transpose(0, 1)
|
||||
|
||||
# For FlashInfer default to HND layout and
|
||||
kv_cache_for_backend = kv_cache_for_backend.transpose(
|
||||
2, 3).contiguous().transpose(2, 3)
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
backend_output = run_attention_backend(backend_name, kv_cache_spec,
|
||||
vllm_config, device,
|
||||
common_attn_metadata,
|
||||
|
||||
@ -66,7 +66,7 @@ def create_common_attn_metadata(
|
||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||
|
||||
# Create block table (random for testing)
|
||||
max_blocks = max(batch_spec.seq_lens) // block_size + 1
|
||||
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
|
||||
block_table_tensor = torch.randint(0,
|
||||
max_block_idx,
|
||||
(batch_spec.batch_size, max_blocks),
|
||||
|
||||
@ -18,6 +18,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
|
||||
@ -158,7 +159,7 @@ class FlashInferMetadata:
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
qo_indptr: torch.Tensor
|
||||
qo_indptr_cpu: torch.Tensor
|
||||
# An example for paged_kv_indices, paged_kv_indptr:
|
||||
# request 1, page indices [0, 5, 8]
|
||||
# request 2, page indices [1, 6, 7]
|
||||
@ -167,13 +168,13 @@ class FlashInferMetadata:
|
||||
# [0, 5, 8, 1, 6, 7, 3, 4]
|
||||
# paged_kv_indptr is used to index into paged_kv_indices:
|
||||
# [0, 3, 6, 8]
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: torch.Tensor
|
||||
# The page indices of the paged kv cache
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
|
||||
paged_kv_indptr_cpu: torch.Tensor
|
||||
# The page indices of the paged kv cache (on device for plan)
|
||||
paged_kv_indices: torch.Tensor
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: torch.Tensor
|
||||
# the paged kv cache, shape: [batch_size] (CPU for plan)
|
||||
paged_kv_last_page_len_cpu: torch.Tensor
|
||||
# The number of query/output heads
|
||||
num_qo_heads: int
|
||||
# The number of key/value heads
|
||||
@ -201,22 +202,17 @@ class FlashInferMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
# For cascade attention.
|
||||
# For cascade attention (CPU for planning).
|
||||
use_cascade: bool
|
||||
shared_qo_indptr: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indptr: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indices: Optional[torch.Tensor] = None
|
||||
shared_kv_last_page_len: Optional[torch.Tensor] = None
|
||||
shared_qo_indptr_cpu: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
|
||||
shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
|
||||
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
|
||||
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
|
||||
|
||||
@property
|
||||
def query_start_loc(self):
|
||||
# The GPUModelRunner expects to be able to access this property.
|
||||
return self.qo_indptr
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_dim is not None:
|
||||
FlashInferBackend.validate_head_size(self.head_dim)
|
||||
@ -238,6 +234,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.vllm_config = vllm_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
max_num_blocks_per_request = cdiv(
|
||||
vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size)
|
||||
self.block_table_arange = torch.arange(max_num_blocks_per_request,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
@ -285,21 +287,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
if self.global_hyperparameters is None:
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
|
||||
|
||||
if attn_metadata.use_cascade:
|
||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||
attn_metadata.cascade_wrapper.plan(
|
||||
[attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
|
||||
[
|
||||
attn_metadata.shared_kv_page_indptr,
|
||||
attn_metadata.paged_kv_indptr
|
||||
attn_metadata.shared_qo_indptr_cpu,
|
||||
attn_metadata.qo_indptr_cpu
|
||||
],
|
||||
[
|
||||
attn_metadata.shared_kv_page_indices,
|
||||
attn_metadata.shared_kv_page_indptr_cpu,
|
||||
attn_metadata.paged_kv_indptr_cpu
|
||||
],
|
||||
[
|
||||
attn_metadata.shared_kv_page_indices_cpu,
|
||||
attn_metadata.paged_kv_indices
|
||||
],
|
||||
[
|
||||
attn_metadata.shared_kv_last_page_len,
|
||||
attn_metadata.paged_kv_last_page_len
|
||||
attn_metadata.shared_kv_last_page_len_cpu,
|
||||
attn_metadata.paged_kv_last_page_len_cpu
|
||||
],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
@ -320,22 +326,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Decodes are first so prefills start after the last decode
|
||||
prefill_start = num_decodes
|
||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||
assert attn_metadata.qo_indptr[prefill_start:].shape[
|
||||
assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
|
||||
assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_last_page_len[
|
||||
assert attn_metadata.paged_kv_last_page_len_cpu[
|
||||
prefill_start:].shape[0] == num_prefills
|
||||
# Since prefill_wrapper.run() will be called with
|
||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||
# to be relative to the start of the prefill queries.
|
||||
qo_indptr = attn_metadata.qo_indptr[
|
||||
prefill_start:] - attn_metadata.qo_indptr[prefill_start]
|
||||
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
|
||||
prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
|
||||
attn_metadata.prefill_wrapper.plan(
|
||||
qo_indptr,
|
||||
attn_metadata.paged_kv_indptr[prefill_start:],
|
||||
qo_indptr_cpu,
|
||||
attn_metadata.paged_kv_indptr_cpu[prefill_start:],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[prefill_start:],
|
||||
attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
@ -357,9 +363,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim):
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:num_decodes + 1],
|
||||
attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:num_decodes],
|
||||
attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
@ -383,55 +389,58 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
# Grab the blocks of the shared prefix from the first request.
|
||||
assert common_prefix_len % page_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // page_size
|
||||
shared_qo_indptr = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_page_indices = block_table_tensor[
|
||||
|
||||
# Create CPU versions directly for cascade (no GPU versions needed)
|
||||
shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
shared_kv_page_indices_cpu = block_table_tensor[
|
||||
0, :num_common_kv_blocks]
|
||||
shared_kv_last_page_len = torch.tensor([page_size],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_last_page_len_cpu = torch.tensor([page_size],
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
|
||||
# Remove the blocks of the shared prefix from all requests.
|
||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||
block_table_bounds -= num_common_kv_blocks
|
||||
block_table_bounds_cpu -= num_common_kv_blocks
|
||||
else:
|
||||
shared_qo_indptr = None
|
||||
shared_kv_page_indptr = None
|
||||
shared_kv_page_indices = None
|
||||
shared_kv_last_page_len = None
|
||||
shared_qo_indptr_cpu = None
|
||||
shared_kv_page_indptr_cpu = None
|
||||
shared_kv_page_indices_cpu = None
|
||||
shared_kv_last_page_len_cpu = None
|
||||
|
||||
mask = (torch.arange(block_table_tensor.size(1),
|
||||
dtype=block_table_tensor.dtype,
|
||||
device=block_table_tensor.device).unsqueeze(0)
|
||||
max_num_blocks = block_table_bounds_cpu.max()
|
||||
block_table_bounds = block_table_bounds_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
|
||||
|
||||
paged_kv_indptr = torch.cat([
|
||||
torch.zeros(1,
|
||||
dtype=block_table_bounds.dtype,
|
||||
device=block_table_bounds.device),
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
|
||||
dim=0, dtype=torch.int32)
|
||||
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
|
||||
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
|
||||
paged_kv_last_page_len_cpu = torch.where(
|
||||
paged_kv_last_page_len_cpu == 0, page_size,
|
||||
paged_kv_last_page_len_cpu)
|
||||
cache_dtype = self.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
@ -440,10 +449,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
kv_cache_dtype = self.kv_cache_spec.dtype
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
qo_indptr=qo_indptr,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
|
||||
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
|
||||
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config),
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
@ -457,14 +466,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
use_cascade=use_cascade,
|
||||
shared_qo_indptr=shared_qo_indptr,
|
||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||
shared_kv_page_indices=shared_kv_page_indices,
|
||||
shared_kv_last_page_len=shared_kv_last_page_len,
|
||||
shared_qo_indptr_cpu=shared_qo_indptr_cpu,
|
||||
shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
|
||||
shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
|
||||
shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table_tensor=block_table_tensor,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
workspace_buffer=self._get_workspace_buffer(),
|
||||
)
|
||||
|
||||
self._plan(num_prefills, num_decodes, attn_metadata)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user