From 96453cfa831340788ef72c42bc2a1a2b4496a27f Mon Sep 17 00:00:00 2001 From: TY-AMD Date: Tue, 1 Jul 2025 16:12:19 +0800 Subject: [PATCH] [BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067) Signed-off-by: Tianyuan Wu --- .../attention/test_attention_selector.py | 6 +- .../attention/test_rocm_attention_selector.py | 6 +- vllm/platforms/rocm.py | 10 +++- vllm/v1/attention/backends/mla/common.py | 9 ++- vllm/v1/attention/backends/mla/triton_mla.py | 57 +++++++++++++++++++ 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index f3e64155703c2..a8ed749ba13b5 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -106,10 +106,8 @@ def test_env( block_size, False, use_mla=use_mla) - if use_v1 and name != "TRITON_MLA": - assert backend.get_name() == f"{name}_VLLM_V1" - else: - assert backend.get_name() == name + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index ed58880cc9e6c..34311b9ccd767 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, None) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) - assert backend.get_name() == "TRITON_MLA" + assert (backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1") # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 08d471d5a983c..ee53a76ceb6db 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -186,8 +186,14 @@ class RocmPlatform(Platform): if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + if use_v1: + logger.info_once( + "Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1878ae74dbc6f..d45ec04472a69 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj - self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + attn_out = self.flash_attn_varlen_func( q=q, k=k, v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, **kwargs, ) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index e26d7909184b5..99938f22f108c 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -5,10 +5,14 @@ from typing import Any, Optional import torch +from vllm import envs from vllm.attention.backends.abstract import (AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata) @@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): raise NotImplementedError( "TritonMLA V1 with FP8 KV cache not yet supported") + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + self.triton_fa_func = triton_attention if HAS_TRITON else None + + def _flash_attn_varlen_diff_headdims_rocm(self, + q, + k, + v, + softmax_scale=None, + **kwargs): + assert self.triton_fa_func is not None + + # Triton Attention requires a padded V + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + # The output of triton_attention is a tuple of + # [output_tensor, encoded_softmax] where encoded_softmax is always None + output_tensor, _ = self.triton_fa_func( + q, + k, + padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + + return output_tensor + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + if current_platform.is_rocm() \ + and self.use_triton_flash_attn \ + and not return_softmax_lse: + return self._flash_attn_varlen_diff_headdims_rocm( + q, k, v, softmax_scale=softmax_scale, **kwargs) + else: + return super()._flash_attn_varlen_diff_headdims( + q, + k, + v, + 
return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs)
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,