mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 21:21:22 +08:00
[ROCm][MTP] Support MTP for AITER MLA backend (#28624)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
parent
104003dc77
commit
9dbbc59b15
@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
|
|||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
MLACommonMetadata,
|
MLACommonMetadata,
|
||||||
MLACommonMetadataBuilder,
|
MLACommonMetadataBuilder,
|
||||||
|
QueryLenSupport,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
|||||||
qo_indptr: torch.Tensor | None = None
|
qo_indptr: torch.Tensor | None = None
|
||||||
# The dtype of MLA out tensor
|
# The dtype of MLA out tensor
|
||||||
attn_out_dtype: torch.dtype = torch.bfloat16
|
attn_out_dtype: torch.dtype = torch.bfloat16
|
||||||
|
# The max query output length: int
|
||||||
|
max_qo_len: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||||
@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
|||||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||||
# TODO(luka, lucas): audit this as part of:
|
# TODO(luka, lucas): audit this as part of:
|
||||||
# https://github.com/vllm-project/vllm/issues/22945
|
# https://github.com/vllm-project/vllm/issues/22945
|
||||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
max_num_reqs, dtype=torch.int32, device=device
|
max_num_reqs, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qo_indptr = torch.arange(
|
self.qo_indptr = torch.zeros(
|
||||||
0, max_num_reqs + 1, dtype=torch.int32, device=device
|
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_decode(
|
def _build_decode(
|
||||||
@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||||
|
max_qo_len = qo_len.max().item()
|
||||||
|
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
num_actual_pages = paged_kv_indices.size(0)
|
num_actual_pages = paged_kv_indices.size(0)
|
||||||
@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
||||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||||
|
|
||||||
|
self.qo_indptr[: 1 + num_reqs].copy_(
|
||||||
|
query_start_loc_device, non_blocking=True
|
||||||
|
)
|
||||||
|
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
|
||||||
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||||
qo_indptr=qo_indptr,
|
qo_indptr=qo_indptr,
|
||||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||||
|
max_qo_len=max_qo_len,
|
||||||
attn_out_dtype=self.decode_attn_out_dtype,
|
attn_out_dtype=self.decode_attn_out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
|
|
||||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||||
|
|
||||||
# max_seqlen_qo must be 1 except for MTP
|
|
||||||
# TODO: Find the best value for MTP
|
|
||||||
max_seqlen_qo = 1
|
|
||||||
rocm_aiter_ops.mla_decode_fwd(
|
rocm_aiter_ops.mla_decode_fwd(
|
||||||
q,
|
q,
|
||||||
kv_buffer,
|
kv_buffer,
|
||||||
o,
|
o,
|
||||||
self.scale,
|
self.scale,
|
||||||
attn_metadata.decode.qo_indptr,
|
attn_metadata.decode.qo_indptr,
|
||||||
max_seqlen_qo,
|
attn_metadata.decode.max_qo_len,
|
||||||
attn_metadata.decode.paged_kv_indptr,
|
attn_metadata.decode.paged_kv_indptr,
|
||||||
attn_metadata.decode.paged_kv_indices,
|
attn_metadata.decode.paged_kv_indices,
|
||||||
attn_metadata.decode.paged_kv_last_page_len,
|
attn_metadata.decode.paged_kv_last_page_len,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user