[AMD] Use Decoupled Kernel Block Size to Support AITER MLA block_size=1 (#27715)

Signed-off-by: chiangzhang <chiangzhang@tencent.com>
Qiang Zhang 2025-11-20 10:11:52 +08:00 committed by GitHub
parent 05c2dee7e9
commit 3fb0d90999
2 changed files with 13 additions and 46 deletions


@@ -119,14 +119,12 @@ class AttentionBackend(ABC):
             return True
         for supported_size in cls.supported_kernel_block_sizes:
-            is_multiple_of = (
-                isinstance(supported_size, MultipleOf)
-                and block_size % supported_size.base == 0
-            )
-            is_int_equal = (
-                isinstance(supported_size, int) and block_size == supported_size
-            )
-            if is_multiple_of or is_int_equal:
+            if isinstance(supported_size, MultipleOf):
+                supported_size = supported_size.base
+            # With hybrid_blocks feature, the framework-level block size
+            # only needs to be a multiple of the kernel's requirement,
+            # even if the kernel requires a fixed block_size.
+            if block_size % supported_size == 0:
                 return True
         return False
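
Editor's note: to show what the relaxed check changes in practice, here is a standalone sketch of the new acceptance rule. It is not the vLLM code itself; the MultipleOf stand-in below only assumes, as the diff shows, that the real class (from vllm.attention.backends.abstract) exposes the wrapped value as .base.

class MultipleOf:
    # Minimal stand-in so this sketch runs on its own.
    def __init__(self, base: int) -> None:
        self.base = base


def supports_block_size_sketch(supported_kernel_block_sizes: list, block_size: int) -> bool:
    # Mirrors the new loop above: a fixed int requirement is treated the same
    # as MultipleOf(requirement), so the framework block size only has to be
    # divisible by the kernel's block size.
    for supported_size in supported_kernel_block_sizes:
        if isinstance(supported_size, MultipleOf):
            supported_size = supported_size.base
        if block_size % supported_size == 0:
            return True
    return False


# A kernel that declares a fixed size of 16 now also accepts 32 or 64, where
# the old exact-match branch would have rejected them (values illustrative):
assert supports_block_size_sketch([16], 32) is True
assert supports_block_size_sketch([16], 24) is False
assert supports_block_size_sketch([MultipleOf(16)], 64) is True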


@@ -7,9 +7,8 @@ from typing import ClassVar
 import torch
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention.backends.abstract import AttentionLayer
+from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
 from vllm.config import VllmConfig
-from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonDecodeMetadata,
@@ -22,6 +21,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 class AiterMLABackend(MLACommonBackend):
+    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_AITER_MLA"
@@ -71,9 +72,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         )
         self.compilation_config = vllm_config.compilation_config
-        max_num_pages_per_req = cdiv(
-            vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
-        )
+        # kernel block size is always 1.
+        max_num_pages_per_req = vllm_config.model_config.max_model_len
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
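
Editor's note: because each kernel page now holds exactly one token, the worst-case number of pages per request is the model's maximum sequence length rather than cdiv(max_model_len, block_size). A worked example with illustrative numbers (neither value comes from the diff):

max_model_len = 32768   # illustrative
cache_block_size = 16   # illustrative KV cache block size

# Old sizing: one page-table entry per KV cache block.
old_pages_per_req = -(-max_model_len // cache_block_size)  # cdiv -> 2048

# New sizing: one entry per token, since the kernel block size is 1.
new_pages_per_req = max_model_len                           # 32768

# The persistent paged_kv buffers are therefore sized for
# max_num_reqs * max_model_len entries in the worst case.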
@@ -82,11 +82,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         # so we can only use the persistent buffer if a cudagraph is actually
         # being used.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            self.block_table_remapping = torch.zeros(
-                [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
-                dtype=torch.int32,
-                device=device,
-            )
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
             )
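
Editor's note: the persistent buffers exist because full CUDA graph capture bakes tensor addresses into the graph, so the builder must allocate worst-case buffers once and update them in place each step instead of allocating fresh tensors. A generic sketch of that pattern (the names are illustrative, not the builder's real fields):

import torch

# Allocate once, at worst-case size, before the graph is captured.
max_num_reqs = 8  # illustrative
buf = torch.zeros(max_num_reqs + 1, dtype=torch.int32)

def update_for_step(indptr: torch.Tensor) -> torch.Tensor:
    # Copy this step's values into a slice of the persistent buffer so the
    # pointer captured by the CUDA graph stays valid across replays.
    n = indptr.numel()
    buf[:n].copy_(indptr, non_blocking=True)
    return buf[:n]

step_indptr = torch.tensor([0, 3, 7], dtype=torch.int32)
view = update_for_step(step_indptr)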
@@ -111,36 +106,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> AiterMLADecodeMetadata:
-        page_size = self.kv_cache_spec.block_size
+        # kernel block size is always 1, although the kv block size is not 1.
         device = self.device
         num_reqs = seq_lens_device.size(0)
-        bs, _ = block_table_tensor.shape
-        block_table_tensor = (
-            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
-        )
-        block_table_tensor = (
-            block_table_tensor
-            + torch.arange(
-                0,
-                page_size,
-                device=block_table_tensor.device,
-                dtype=block_table_tensor.dtype,
-            )[None, None, :]
-        )
-        block_table_tensor = block_table_tensor.view(bs, -1)
-        # after remapping, we assume the block size already equals to 1
         max_blk_size_per_req = block_table_tensor.shape[-1]
         mask = torch.arange(
             block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
         ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
-        paged_kv_last_page_len = seq_lens_device % page_size
-        paged_kv_last_page_len = torch.where(
-            paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
-        )
+        paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
         paged_kv_indptr = torch.cat(
             [
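
Editor's note: previously this builder received a block table at the KV cache's block granularity and expanded each block id into per-token slot indices on the fly; the comment added above indicates the incoming block table is now assumed to be at kernel (per-token) granularity already, so the indices can be masked by sequence length and used directly. A small numeric sketch of the removed expansion, assuming an illustrative cache block size of 4:

import torch

page_size = 4  # illustrative KV cache block size
block_table = torch.tensor([[2, 5]])  # one request, two cache blocks

# The removed code expanded each block id b into [b*4, b*4+1, b*4+2, b*4+3]:
expanded = block_table.unsqueeze(-1).expand(-1, -1, page_size) * page_size
expanded = expanded + torch.arange(page_size)[None, None, :]
expanded = expanded.reshape(1, -1)
# expanded -> [[ 8,  9, 10, 11, 20, 21, 22, 23]]

# After this commit the block table handed to the builder is assumed to be
# per-token already, so only the seq-len mask is applied:
seq_lens = torch.tensor([6])
mask = torch.arange(expanded.size(1))[None, :] < seq_lens[:, None]
paged_kv_indices = expanded[mask]  # first 6 token slots of the request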
@@ -151,12 +126,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
-            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
-                block_table_tensor, non_blocking=True
-            )
-            block_table_tensor = self.block_table_remapping[
-                :num_reqs, :max_blk_size_per_req
-            ]
             self.paged_kv_indices[:num_actual_pages].copy_(
                 paged_kv_indices, non_blocking=True