diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 6bee9e4ce1166..229d9862fb670 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel(
   p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
   s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
   const float max_lse = fmaxf(p_lse, s_lse);
+
+  /* In some edge cases, MLA can produce p_lse = s_lse = -inf, and the
+     computation below then yields NaN. Root cause: with chunked prefill
+     a batch may be split into two chunks, and cross-attention is merged
+     first; if a request in that batch has no prefix-cache hit, every
+     LSE entry for that request is -inf. For now, we simply forward
+     prefix_output (expected to be all zeros) and prefix_lse (-inf) to
+     the output to fix this problem.
+  */
+  if (std::isinf(max_lse)) {
+    if (pack_offset < head_size) {
+      // Pack 128b load
+      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+          prefix_head_ptr)[pack_offset / pack_size];
+
+      // Pack 128b storage
+      reinterpret_cast<pack_128b_t*>(
+          output_head_ptr)[pack_offset / pack_size] = p_out_pack;
+    }
+    // We only need to write to output_lse once per head.
+    if (output_lse != nullptr && pack_idx == 0) {
+      output_lse[head_idx * num_tokens + token_idx] = max_lse;
+    }
+    return;
+  }
+
   p_lse = p_lse - max_lse;
   s_lse = s_lse - max_lse;
   const float p_se = expf(p_lse);
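
For context, here is a minimal host-side sketch (an illustration, not part of the patch; plain C++ reproducing the same float arithmetic) of why the early return is needed. When both LSEs are -inf, max_lse is also -inf, so the subtraction "p_lse - max_lse" evaluates (-inf) - (-inf) = NaN under IEEE 754, which then poisons expf() and the merge scale factors:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float inf = std::numeric_limits<float>::infinity();

  // No prefix-cache hit: both LSE entries are -inf.
  float p_lse = -inf;
  float s_lse = -inf;
  const float max_lse = fmaxf(p_lse, s_lse);  // still -inf

  // Without the guard: (-inf) - (-inf) == NaN ...
  const float p_shift = p_lse - max_lse;
  const float s_shift = s_lse - max_lse;
  // ... and the NaN propagates through the softmax-style rescaling.
  const float p_scale = expf(p_shift) / (expf(p_shift) + expf(s_shift));
  printf("p_shift = %f, p_scale = %f\n", p_shift, p_scale);  // nan, nan

  // With the guard, the kernel returns early instead, writing
  // prefix_output (expected to be all zeros) and lse = -inf.
  if (std::isinf(max_lse)) {
    printf("guard taken: emit prefix_output and lse = %f\n", max_lse);
  }
  return 0;
}

In the normal case (at least one finite LSE), max_lse is finite, both subtractions are well-defined, and the existing rescaling path below the early return is unaffected.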