[ROCm][MTP] Support MTP for AITER MLA backend (#28624)
Signed-off-by: ganyi <ygan@amd.com>
parent 104003dc77
commit 9dbbc59b15
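With multi-token prediction (MTP), a decode step verifies several speculative tokens per request, so every request in the decode batch has a uniform query length greater than one; this commit threads that length (max_qo_len) through the AITER MLA decode path instead of assuming 1. A minimal sketch of the index layout involved (illustrative values only, not taken from the diff):

    import torch

    # Hypothetical decode batch: 4 requests, each carrying 1 sampled token
    # plus 2 MTP draft tokens, so every request has query length q = 3.
    num_reqs, q = 4, 3
    query_lens = torch.full((num_reqs,), q, dtype=torch.int32)

    # qo_indptr is the exclusive prefix sum of query lengths: [0, 3, 6, 9, 12].
    # Without MTP, q == 1 and this degenerates to torch.arange(num_reqs + 1).
    qo_indptr = torch.zeros(num_reqs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(query_lens, dim=0)

    max_qo_len = int(query_lens.max())  # 3; the old code hardcoded this to 1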
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
     MLACommonImpl,
     MLACommonMetadata,
     MLACommonMetadataBuilder,
+    QueryLenSupport,
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     qo_indptr: torch.Tensor | None = None
     # The dtype of MLA out tensor
     attn_out_dtype: torch.dtype = torch.bfloat16
+    # The max query output length: int
+    max_qo_len: int | None = None


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
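The new max_qo_len field rides along on the decode metadata so the kernel call site (last hunk below) no longer hardcodes a query length of 1. A hedged sketch of the simplified shape of this metadata, showing only the fields this commit touches (names and defaults mirror the diff; the real AiterMLADecodeMetadata has more fields):

    from dataclasses import dataclass

    import torch

    @dataclass
    class AiterMLADecodeMetadataSketch:
        # Simplified stand-in for AiterMLADecodeMetadata.
        qo_indptr: torch.Tensor | None = None
        attn_out_dtype: torch.dtype = torch.bfloat16
        # New in this commit: largest query length in the decode batch,
        # > 1 whenever MTP draft tokens are being verified.
        max_qo_len: int | None = None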
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    _cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

     def __init__(
         self,
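Moving from UNIFORM_SINGLE_TOKEN_DECODE to UNIFORM_BATCH lets full CUDA graphs cover decode batches whose query length is uniform but larger than one, while QueryLenSupport.UNIFORM declares that the backend still requires all requests to share one query length. A hypothetical helper (not part of the diff) that states that invariant:

    import torch

    def batch_is_uniform(query_start_loc: torch.Tensor) -> bool:
        # QueryLenSupport.UNIFORM: every request in the batch must have the
        # same query length (e.g. 1 + number of MTP draft tokens).
        qo_len = query_start_loc[1:] - query_start_loc[:-1]
        return bool((qo_len == qo_len[0]).all())

    assert batch_is_uniform(torch.tensor([0, 3, 6, 9]))      # uniform q = 3
    assert not batch_is_uniform(torch.tensor([0, 1, 4, 5]))  # ragged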
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             max_num_reqs, dtype=torch.int32, device=device
         )

-        self.qo_indptr = torch.arange(
-            0, max_num_reqs + 1, dtype=torch.int32, device=device
+        self.qo_indptr = torch.zeros(
+            max_num_reqs + 1, dtype=torch.int32, device=device
         )

     def _build_decode(
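The persistent qo_indptr buffer was previously initialized with torch.arange, which bakes in a step of exactly one token per request; once query lengths can exceed one, the contents must be rewritten on every build, so a zero-initialized buffer is the honest starting state. A small illustration of why the arange layout breaks under MTP (illustrative values):

    import torch

    max_num_reqs = 4
    arange_indptr = torch.arange(0, max_num_reqs + 1, dtype=torch.int32)
    # arange_indptr == [0, 1, 2, 3, 4]: valid only when every query length is 1.
    # With q = 3 per request the correct indptr is [0, 3, 6, 9, 12], so the
    # buffer is now zeroed at init and filled from query_start_loc per build.
    qo_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32)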
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
                 seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_qo_len = qo_len.max().item()

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
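The maximum query length is recovered directly from the CPU copy of query_start_loc, whose consecutive differences are the per-request query lengths. For example (values assumed, not from the diff):

    import torch

    # query_start_loc for 3 requests, each with query length 3
    query_start_loc_cpu = torch.tensor([0, 3, 6, 9], dtype=torch.int32)

    qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # tensor([3, 3, 3])
    max_qo_len = qo_len.max().item()                             # 3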
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             self.paged_kv_last_page_len[num_reqs:].fill_(1)
             paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

+            self.qo_indptr[: 1 + num_reqs].copy_(
+                query_start_loc_device, non_blocking=True
+            )
+            self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
             qo_indptr = self.qo_indptr[: 1 + num_reqs]

         else:
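Under full CUDA graphs the batch is padded to a fixed size, so the tail of the persistent qo_indptr buffer is filled with the last real offset: padded requests then see a zero-length query and contribute nothing. A self-contained sketch of the same padding on CPU (non_blocking dropped since there is no device copy here):

    import torch

    max_num_reqs, num_reqs = 8, 3
    qo_indptr_buf = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
    query_start_loc = torch.tensor([0, 3, 6, 9], dtype=torch.int32)

    qo_indptr_buf[: 1 + num_reqs].copy_(query_start_loc)
    qo_indptr_buf[1 + num_reqs :] = query_start_loc[-1]
    # qo_indptr_buf == [0, 3, 6, 9, 9, 9, 9, 9, 9]; entries past num_reqs
    # repeat the final offset, so padded requests have query length 0.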
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_qo_len=max_qo_len,
             attn_out_dtype=self.decode_attn_out_dtype,
         )

@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

-        # max_seqlen_qo must be 1 except for MTP
-        # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         rocm_aiter_ops.mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_qo_len,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
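The hardcoded max_seqlen_qo = 1 (and its TODO) disappear: the kernel now receives the batch's real maximum query length from the metadata. A hypothetical consistency check (not in the diff) relating the two arguments:

    import torch

    def check_decode_args(qo_indptr: torch.Tensor, max_qo_len: int) -> None:
        # max_qo_len must match the largest step in qo_indptr; before this
        # commit both were implicitly 1 outside of MTP.
        largest = (qo_indptr[1:] - qo_indptr[:-1]).max().item()
        assert largest == max_qo_len, (largest, max_qo_len)

    check_decode_args(torch.tensor([0, 3, 6, 9]), 3)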