[ROCm][MLA] Support block-size > 1 for AITER MLA backend (#27224)

Signed-off-by: ganyi <ygan@amd.com>
Co-authored-by: wuhuikx <hattie.wu@amd.com>
This commit is contained in:
Pleaplusone 2025-11-05 23:43:02 +08:00 committed by GitHub
parent 80c9275348
commit 6cae1e5332
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 24 deletions

View File

@ -104,13 +104,6 @@ def test_env(
16, torch.float16, None, block_size, use_mla=use_mla 16, torch.float16, None, block_size, use_mla=use_mla
) )
assert f"The selected backend, {name}" in str(exc_info.value) assert f"The selected backend, {name}" in str(exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
else: else:
# Valid backend-block_size combination # Valid backend-block_size combination
backend = get_attn_backend( backend = get_attn_backend(

View File

@ -252,16 +252,9 @@ class RocmPlatform(Platform):
f"does not support block size {block_size}." f"does not support block size {block_size}."
) )
if selected_backend == _Backend.ROCM_AITER_MLA: if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1: logger.info("Using AITER MLA backend.")
logger.info("Using AITER MLA backend.") return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
return (
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
)
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)"
)
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend." f"is not MLA type while requested for MLA backend."

View File

@ -78,9 +78,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
super().__init__( super().__init__(
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
) )
assert self.kv_cache_spec.block_size == 1, (
"AITER MLAonly supports block size 1."
)
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv( max_num_pages_per_req = cdiv(
@ -94,6 +91,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# so we can only use the persistent buffer if a cudagraph is actually # so we can only use the persistent buffer if a cudagraph is actually
# being used. # being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.block_table_remapping = torch.zeros(
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros( self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
@ -119,13 +121,29 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dcp_tot_seq_lens_device: torch.Tensor | None, dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata: ) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device device = self.device
num_reqs = seq_lens_device.size(0) num_reqs = seq_lens_device.size(0)
bs, _ = block_table_tensor.shape
block_table_tensor = (
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
)
block_table_tensor = (
block_table_tensor
+ torch.arange(
0,
page_size,
device=block_table_tensor.device,
dtype=block_table_tensor.dtype,
)[None, None, :]
)
block_table_tensor = block_table_tensor.view(bs, -1)
# after remapping, we assume the block size already equals to 1
max_blk_size_per_req = block_table_tensor.shape[-1]
mask = torch.arange( mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < block_table_bounds.unsqueeze(1) ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask] paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = seq_lens_device % page_size
@ -135,13 +153,19 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_indptr = torch.cat( paged_kv_indptr = torch.cat(
[ [
torch.zeros(1, dtype=block_table_bounds.dtype, device=device), torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32), seq_lens_device.cumsum(dim=0, dtype=torch.int32),
] ]
) )
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
block_table_tensor, non_blocking=True
)
block_table_tensor = self.block_table_remapping[
:num_reqs, :max_blk_size_per_req
]
self.paged_kv_indices[:num_actual_pages].copy_( self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True paged_kv_indices, non_blocking=True