diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index db79b3f5e8bc..a8f472d147a0 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -294,6 +294,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -308,6 +310,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
         max_seqlen_qo,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )


@@ -322,6 +326,8 @@ def _rocm_aiter_mla_decode_fwd_fake(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     pass

@@ -806,6 +812,8 @@ class rocm_aiter_ops:
         kv_indices: torch.Tensor | None = None,
         kv_last_page_lens: torch.Tensor | None = None,
         logit_cap: float = 0.0,
+        q_scale: torch.Tensor | None = None,
+        kv_scale: torch.Tensor | None = None,
     ):
         torch.ops.vllm.rocm_aiter_mla_decode_fwd(
             q,
@@ -818,6 +826,8 @@ class rocm_aiter_ops:
             kv_last_page_lens,
             sm_scale=sm_scale,
             logit_cap=logit_cap,
+            q_scale=q_scale,
+            kv_scale=kv_scale,
         )

     @staticmethod
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 56f9c7a281e7..00a0a77a1c2f 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -49,6 +49,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     paged_kv_last_page_len: torch.Tensor | None = None
     # The query indptr, shape : [num_decode + 1]
     qo_indptr: torch.Tensor | None = None
+    # The dtype of the MLA output tensor
+    attn_out_dtype: torch.dtype = torch.bfloat16


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -74,6 +76,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         )

         self.compilation_config = vllm_config.compilation_config
+        self.decode_attn_out_dtype = vllm_config.model_config.dtype
         # kernel block size is always 1.
         max_num_pages_per_req = vllm_config.model_config.max_model_len
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
@@ -162,6 +165,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            attn_out_dtype=self.decode_attn_out_dtype,
         )

         return attn_metadata
@@ -242,7 +246,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
         o = torch.zeros(
-            B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+            B,
+            self.num_heads,
+            self.kv_lora_rank,
+            dtype=attn_metadata.decode.attn_out_dtype,
+            device=q.device,
         )

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
@@ -260,6 +268,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
         )

         return o, None
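For context (not part of the patch): a minimal sketch of the behavior this change enables. When the query and KV cache are FP8-quantized, `q.dtype` is no longer a suitable dtype for the attention output, so the decode output buffer is allocated in the model dtype (`attn_out_dtype`) and the layer's per-tensor scales are forwarded to the AITER kernel as `q_scale` / `kv_scale`. The shapes, values, and the `float8_e4m3fn` dtype below are illustrative assumptions, not values taken from vLLM.

```python
import torch

# Placeholder sizes, chosen only for illustration.
B, num_heads, kv_lora_rank, qk_rope_dim = 4, 16, 512, 64
model_dtype = torch.bfloat16  # corresponds to attn_out_dtype / model_config.dtype

# Assume an FP8-quantized query (requires a PyTorch build with float8 support).
q = torch.randn(B, num_heads, kv_lora_rank + qk_rope_dim).to(torch.float8_e4m3fn)

# The output buffer now follows the model dtype instead of q.dtype,
# mirroring the torch.zeros(...) change in AiterMLAImpl.
o = torch.zeros(B, num_heads, kv_lora_rank, dtype=model_dtype, device=q.device)

# Per-tensor dequantization scales, analogous to layer._q_scale / layer._k_scale,
# which the patch forwards to mla_decode_fwd as q_scale / kv_scale.
q_scale = torch.tensor([1.0], device=q.device)
kv_scale = torch.tensor([1.0], device=q.device)
```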