From 33437bc6e7af316fa9ce6b6e559501ca45d9cd45 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Mar 2025 23:33:22 -0400 Subject: [PATCH] [BugFix] Fix nightly MLA failure (FA2 + MLA chunked prefill, i.e. V1, producing bad results) (#15492) Signed-off-by: LucasWilkinson --- vllm/attention/ops/triton_merge_attn_states.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 31545b607fecd..9671b933f47b9 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -54,6 +54,15 @@ def merge_attn_states_kernel( p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + + # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely + # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf. + # If we see an inf assume FA2 and convert inf to -inf for consistency + # and correctness. Inf generally doesn't make sense in this context outside + # of undefined-behavior/FA2-case, so I think this is a safe assumption. + p_lse = float('-inf') if p_lse == float('inf') else p_lse + s_lse = float('-inf') if s_lse == float('inf') else s_lse + max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse s_lse = s_lse - max_lse