[ROCm][MLA] Support block-size > 1 for AITER MLA backend (#27224)

Signed-off-by: ganyi <ygan@amd.com>
Co-authored-by: wuhuikx <hattie.wu@amd.com>
Pleaplusone 2025-11-05 23:43:02 +08:00 committed by GitHub
parent 80c9275348
commit 6cae1e5332
3 changed files with 34 additions and 24 deletions

tests/kernels/attention/test_attention_selector.py

@@ -104,13 +104,6 @@ def test_env(
                 16, torch.float16, None, block_size, use_mla=use_mla
             )
         assert f"The selected backend, {name}" in str(exc_info.value)
-    elif name == "ROCM_AITER_MLA" and block_size != 1:
-        # ROCM_AITER_MLA only supports block_size == 1
-        with pytest.raises(ValueError) as exc_info:
-            get_attn_backend(
-                16, torch.float16, None, block_size, use_mla=use_mla
-            )
-        assert f"The selected backend, {name}" in str(exc_info.value)
     else:
         # Valid backend-block_size combination
         backend = get_attn_backend(
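
Note: with the block-size guard removed, ROCM_AITER_MLA is expected to resolve
for any supported block size instead of raising ValueError. A minimal
standalone sketch of that check, assuming a ROCm build where the AITER MLA
backend is selectable and reusing the call shape from the hunk above:

    import torch
    from vllm.attention.selector import get_attn_backend

    # Previously block_size != 1 raised "The selected backend, ROCM_AITER_MLA,
    # does not support block size ..."; now this should return the backend.
    backend = get_attn_backend(16, torch.float16, None, 16, use_mla=True)
    print(backend.get_name())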

vllm/platforms/rocm.py

@@ -252,16 +252,9 @@ class RocmPlatform(Platform):
                     f"does not support block size {block_size}."
                 )
             if selected_backend == _Backend.ROCM_AITER_MLA:
-                if block_size == 1:
-                    logger.info("Using AITER MLA backend.")
-                    return (
-                        "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
-                    )
-                raise ValueError(
-                    f" The selected backend, {selected_backend.name},"
-                    f"does not support block size {block_size}."
-                    "(currently only supports block size 1)"
-                )
+                logger.info("Using AITER MLA backend.")
+                return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
             raise ValueError(
                 f" The selected backend, {selected_backend.name},"
                 f"is not MLA type while requested for MLA backend."

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

@@ -78,9 +78,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         super().__init__(
             kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
         )
-        assert self.kv_cache_spec.block_size == 1, (
-            "AITER MLA only supports block size 1."
-        )
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
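
For reference, cdiv is vLLM's ceiling-division helper (from vllm.utils); a
one-line sketch of the usual definition:

    def cdiv(a: int, b: int) -> int:
        # ceiling division: cdiv(33, 16) == 3
        return -(a // -b)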
@@ -94,6 +91,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         # so we can only use the persistent buffer if a cudagraph is actually
         # being used.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.block_table_remapping = torch.zeros(
+                [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
+                dtype=torch.int32,
+                device=device,
+            )
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
             )
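
Note on sizing: after remapping, a request spanning max_num_pages_per_req
physical blocks needs max_num_pages_per_req * block_size unit-slot entries,
hence the buffer shape above. A small sketch of the arithmetic, with assumed
example values:

    import torch

    max_num_reqs = 8            # assumed example value
    max_num_pages_per_req = 32  # assumed example value
    block_size = 16             # physical KV-cache block size (assumed)

    # One row per request; one int32 slot id per token position.
    block_table_remapping = torch.zeros(
        [max_num_reqs, max_num_pages_per_req * block_size], dtype=torch.int32
    )
    print(block_table_remapping.shape)  # torch.Size([8, 512])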
@@ -119,13 +121,29 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
-        block_table_bounds = (seq_lens_device + page_size - 1) // page_size
         device = self.device
+        num_reqs = seq_lens_device.size(0)
+        bs, _ = block_table_tensor.shape
+        block_table_tensor = (
+            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
+        )
+        block_table_tensor = (
+            block_table_tensor
+            + torch.arange(
+                0,
+                page_size,
+                device=block_table_tensor.device,
+                dtype=block_table_tensor.dtype,
+            )[None, None, :]
+        )
+        block_table_tensor = block_table_tensor.view(bs, -1)
+        # after remapping, we assume the block size already equals 1
+        max_blk_size_per_req = block_table_tensor.shape[-1]
         mask = torch.arange(
             block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
-        ).unsqueeze(0) < block_table_bounds.unsqueeze(1)
+        ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
         paged_kv_last_page_len = seq_lens_device % page_size
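
The remapping above turns a block table over size-page_size physical blocks
into a table of unit-sized slots: block id b expands to slot ids
b * page_size .. b * page_size + page_size - 1, so the rest of the metadata
path can keep treating the cache as if block_size == 1. A standalone sketch
with assumed example values:

    import torch

    page_size = 4
    block_table = torch.tensor([[2, 5]], dtype=torch.int32)  # [bs=1, 2 pages]
    bs, _ = block_table.shape

    # Expand each physical block id into page_size consecutive slot ids.
    expanded = block_table.unsqueeze(-1).expand(-1, -1, page_size) * page_size
    expanded = expanded + torch.arange(
        page_size, dtype=block_table.dtype
    )[None, None, :]
    unit_table = expanded.view(bs, -1)
    print(unit_table)  # tensor([[ 8,  9, 10, 11, 20, 21, 22, 23]], ...)

    # With unit slots, the mask compares positions against seq_lens directly:
    seq_lens = torch.tensor([6])
    mask = torch.arange(unit_table.size(1)).unsqueeze(0) < seq_lens.unsqueeze(1)
    print(unit_table[mask])  # tensor([ 8,  9, 10, 11, 20, 21], ...)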
@@ -135,13 +153,19 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         paged_kv_indptr = torch.cat(
             [
-                torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
-                block_table_bounds.cumsum(dim=0, dtype=torch.int32),
+                torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
+                seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
+            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
+                block_table_tensor, non_blocking=True
+            )
+            block_table_tensor = self.block_table_remapping[
+                :num_reqs, :max_blk_size_per_req
+            ]
             self.paged_kv_indices[:num_actual_pages].copy_(
                 paged_kv_indices, non_blocking=True
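
With unit-sized pages, each request occupies exactly seq_len pages, so the
CSR-style indptr above reduces to a zero-prefixed cumulative sum of the
sequence lengths; the old per-request block counts (block_table_bounds) are
no longer needed. A standalone sketch with assumed values:

    import torch

    seq_lens = torch.tensor([6, 3], dtype=torch.int32)
    paged_kv_indptr = torch.cat(
        [
            torch.zeros(1, dtype=seq_lens.dtype),
            seq_lens.cumsum(dim=0, dtype=torch.int32),
        ]
    )
    print(paged_kv_indptr)  # tensor([0, 6, 9], dtype=torch.int32)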