mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:45:19 +08:00
[AMD] Use Decoupled Kernel Block Size to Support AITER MLA block_size=1 (#27715)
Signed-off-by: chiangzhang <chiangzhang@tencent.com>
This commit is contained in:
parent
05c2dee7e9
commit
3fb0d90999
@ -119,14 +119,12 @@ class AttentionBackend(ABC):
|
||||
return True
|
||||
|
||||
for supported_size in cls.supported_kernel_block_sizes:
|
||||
is_multiple_of = (
|
||||
isinstance(supported_size, MultipleOf)
|
||||
and block_size % supported_size.base == 0
|
||||
)
|
||||
is_int_equal = (
|
||||
isinstance(supported_size, int) and block_size == supported_size
|
||||
)
|
||||
if is_multiple_of or is_int_equal:
|
||||
if isinstance(supported_size, MultipleOf):
|
||||
supported_size = supported_size.base
|
||||
# With hybrid_blocks feature, the framework-level block size
|
||||
# only needs to be a multiple of the kernel's requirement,
|
||||
# even if the kernel requires a fixed block_size.
|
||||
if block_size % supported_size == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -7,9 +7,8 @@ from typing import ClassVar
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@ -22,6 +21,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
@ -71,9 +72,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
)
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
max_num_pages_per_req = cdiv(
|
||||
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
|
||||
)
|
||||
# kernel block size is always 1.
|
||||
max_num_pages_per_req = vllm_config.model_config.max_model_len
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
@ -82,11 +82,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.block_table_remapping = torch.zeros(
|
||||
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
@ -111,36 +106,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
# kernel block size is always 1, although the kv block size is not 1.
|
||||
device = self.device
|
||||
num_reqs = seq_lens_device.size(0)
|
||||
bs, _ = block_table_tensor.shape
|
||||
block_table_tensor = (
|
||||
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
|
||||
)
|
||||
block_table_tensor = (
|
||||
block_table_tensor
|
||||
+ torch.arange(
|
||||
0,
|
||||
page_size,
|
||||
device=block_table_tensor.device,
|
||||
dtype=block_table_tensor.dtype,
|
||||
)[None, None, :]
|
||||
)
|
||||
block_table_tensor = block_table_tensor.view(bs, -1)
|
||||
|
||||
# after remapping, we assume the block size already equals to 1
|
||||
|
||||
max_blk_size_per_req = block_table_tensor.shape[-1]
|
||||
mask = torch.arange(
|
||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
paged_kv_last_page_len = seq_lens_device % page_size
|
||||
paged_kv_last_page_len = torch.where(
|
||||
paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
|
||||
)
|
||||
paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
|
||||
|
||||
paged_kv_indptr = torch.cat(
|
||||
[
|
||||
@ -151,12 +126,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
|
||||
block_table_tensor, non_blocking=True
|
||||
)
|
||||
block_table_tensor = self.block_table_remapping[
|
||||
:num_reqs, :max_blk_size_per_req
|
||||
]
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||
paged_kv_indices, non_blocking=True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user