diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 00a0a77a1c2f7..589d6ef2f6348 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
     MLACommonImpl,
     MLACommonMetadata,
     MLACommonMetadataBuilder,
+    QueryLenSupport,
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     qo_indptr: torch.Tensor | None = None
     # The dtype of MLA out tensor
     attn_out_dtype: torch.dtype = torch.bfloat16
+    # The max query output length
+    max_qo_len: int | None = None
 
 
 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    _cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
 
     def __init__(
         self,
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             max_num_reqs, dtype=torch.int32, device=device
         )
 
-        self.qo_indptr = torch.arange(
-            0, max_num_reqs + 1, dtype=torch.int32, device=device
+        self.qo_indptr = torch.zeros(
+            max_num_reqs + 1, dtype=torch.int32, device=device
         )
 
     def _build_decode(
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
                 seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_qo_len = qo_len.max().item()
 
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             self.paged_kv_last_page_len[num_reqs:].fill_(1)
             paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
 
+            self.qo_indptr[: 1 + num_reqs].copy_(
+                query_start_loc_device, non_blocking=True
+            )
+            self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
             qo_indptr = self.qo_indptr[: 1 + num_reqs]
 
         else:
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_qo_len=max_qo_len,
             attn_out_dtype=self.decode_attn_out_dtype,
         )
 
@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        # max_seqlen_qo must be 1 except for MTP
-        # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         rocm_aiter_ops.mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_qo_len,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
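
For intuition, here is a minimal standalone sketch (not vLLM code; the tensor values and batch sizes are invented) of what the new `qo_indptr` padding and `max_qo_len` bookkeeping do under full CUDA graph capture: the padded tail of the indptr buffer repeats the last real query offset, so padded requests decode zero query tokens, and `max_qo_len` replaces the previously hardcoded `max_seqlen_qo = 1`, allowing uniform multi-token (e.g. MTP/speculative) decode batches.

```python
import torch

# Hypothetical sizes for illustration: a CUDA graph captured for up to
# 5 requests, currently running 3 real requests of 2 query tokens each.
max_num_reqs, num_reqs = 5, 3

# Persistent buffer, zero-initialized as in the diff (torch.zeros
# replaces the old torch.arange initialization).
qo_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32)

# Cumulative query offsets for the real requests: lengths [2, 2, 2].
query_start_loc = torch.tensor([0, 2, 4, 6], dtype=torch.int32)

# Copy the real offsets, then repeat the final offset into the padded
# tail so every padded request has query length 0.
qo_indptr[: 1 + num_reqs].copy_(query_start_loc)
qo_indptr[1 + num_reqs :] = query_start_loc[-1]
print(qo_indptr.tolist())  # [0, 2, 4, 6, 6, 6]

# Per-request query lengths and their max, mirroring the qo_len /
# max_qo_len computation that replaces the hardcoded max_seqlen_qo = 1.
qo_len = query_start_loc[1:] - query_start_loc[:-1]
max_qo_len = qo_len.max().item()
print(max_qo_len)  # 2
```

Keeping the padded entries equal to the last offset is what makes the `UNIFORM_BATCH` CUDA graph mode safe: the kernel sees a full-size indptr of fixed shape, but the extra requests contribute no query tokens.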