[ROCm][MTP] Support MTP for AITER MLA backend (#28624)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone 2025-12-16 22:10:26 +08:00 committed by GitHub
parent 104003dc77
commit 9dbbc59b15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr: torch.Tensor | None = None qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor # The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16 attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length
max_qo_len: int | None = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of: # TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945 # https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
)
def __init__( def __init__(
self, self,
@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs, dtype=torch.int32, device=device max_num_reqs, dtype=torch.int32, device=device
) )
self.qo_indptr = torch.arange( self.qo_indptr = torch.zeros(
0, max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
def _build_decode( def _build_decode(
@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device.cumsum(dim=0, dtype=torch.int32), seq_lens_device.cumsum(dim=0, dtype=torch.int32),
] ]
) )
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_qo_len = qo_len.max().item()
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_last_page_len[num_reqs:].fill_(1) self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True
)
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
qo_indptr = self.qo_indptr[: 1 + num_reqs] qo_indptr = self.qo_indptr[: 1 + num_reqs]
else: else:
@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype, attn_out_dtype=self.decode_attn_out_dtype,
) )
@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
rocm_aiter_ops.mla_decode_fwd( rocm_aiter_ops.mla_decode_fwd(
q, q,
kv_buffer, kv_buffer,
o, o,
self.scale, self.scale,
attn_metadata.decode.qo_indptr, attn_metadata.decode.qo_indptr,
max_seqlen_qo, attn_metadata.decode.max_qo_len,
attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len, attn_metadata.decode.paged_kv_last_page_len,