From ce75efeecb57acb5421aeb545a95e922f3dc8b3e Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Wed, 28 May 2025 04:59:39 -0400
Subject: [PATCH] [BugFix] FA2 MLA Accuracy Issue (#18807)

Signed-off-by: Lucas Wilkinson
---
 csrc/attention/merge_attn_states.cu      | 8 ++++++++
 vllm/attention/backends/mla/common.py    | 8 ++++----
 vllm/v1/attention/backends/mla/common.py | 8 ++++----
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 14e5edd7e283..6bee9e4ce116 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
+  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,