[Bugfix][Kernel] fix merge attn states when both prefix and suffix are empty (#28181)

Signed-off-by: courage17340 <courage17340@163.com>
This commit is contained in:
courage17340 2025-11-06 17:52:13 +08:00 committed by GitHub
parent c3ee80a01a
commit 981cadb35c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel(
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const float max_lse = fmaxf(p_lse, s_lse);
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that requests position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
output_lse[head_idx * num_tokens + token_idx] = max_lse;
}
return;
}
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);