diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 51899b023591..91e1cad01f4f 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -46,6 +46,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
     XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
     ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
     ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
+    ROCM_AITER_TRITON_MLA = (
+        "vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
+    )
     ROCM_AITER_FA = (
         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
     )
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index bb116792fed5..f07f068a9249 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -234,7 +234,6 @@ class RocmPlatform(Platform):
                     if rocm_aiter_ops.is_mla_enabled() or block_size == 1
                     else AttentionBackendEnum.TRITON_MLA
                 )
-
             if selected_backend == AttentionBackendEnum.TRITON_MLA:
                 if block_size != 1:
                     logger.info_once("Using Triton MLA backend.")
@@ -246,6 +245,9 @@ class RocmPlatform(Platform):
             if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
                 logger.info("Using AITER MLA backend.")
                 return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
+            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
+                logger.info("Using AITER TRITON MLA backend.")
+                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
 
             raise ValueError(
                 f" The selected backend, {selected_backend.name},"
diff --git a/vllm/v1/attention/backends/mla/aiter_triton_mla.py b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
new file mode 100644
index 000000000000..8a92152a0ca5
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.v1.attention.backends.mla.common import MLACommonBackend
+from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
+    AiterMLAImpl,
+    AiterMLAMetadataBuilder,
+)
+
+
+class AiterTritonMLABackend(MLACommonBackend):
+    @staticmethod
+    def get_name() -> str:
+        return "AITER_TRITON_MLA"
+
+    @staticmethod
+    def get_impl_cls() -> type["AiterTritonMLAImpl"]:
+        return AiterTritonMLAImpl
+
+    @staticmethod
+    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
+        return AiterMLAMetadataBuilder
+
+
+class AiterTritonMLAImpl(AiterMLAImpl):
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
+        kv_cache_dtype: str,
+        logits_soft_cap: float | None,
+        attn_type: str,
+        kv_sharing_target_layer_name: str | None,
+        # MLA Specific Arguments
+        **mla_args,
+    ) -> None:
+        super().__init__(
+            num_heads,
+            head_size,
+            scale,
+            num_kv_heads,
+            alibi_slopes,
+            sliding_window,
+            kv_cache_dtype,
+            logits_soft_cap,
+            attn_type,
+            kv_sharing_target_layer_name,
+            **mla_args,
+        )
+        from aiter.ops.triton.mha import flash_attn_varlen_func
+
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+
+    def _flash_attn_varlen_diff_headdims(
+        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
+    ):
+        result = self.flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            softmax_scale=softmax_scale,
+            return_lse=return_softmax_lse,
+            **kwargs,
+        )
+        # Transpose the LSE if Triton MHA is used:
+        # (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
+        if type(result) is tuple and return_softmax_lse:
+            output, lse = result
+            lse = lse.T.contiguous()
+            return (output, lse)
+        return result