bugfix: correct attn output with base 2 or e (#28840)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
Augusto Yao 2025-11-29 07:52:12 +08:00 committed by GitHub
parent 3fd1fb0b60
commit 9726e64530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 13 deletions

View File

@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
lse_idx,
HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr,
IS_BASE_E: tl.constexpr,
):
"""
Apply the all-gathered lses to correct each local rank's attention
@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
if IS_BASE_E:
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
else:
lse_exp = tl.exp2(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log2(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
-float("inf"),
lse_finally,
)
factor = tl.exp(lse_finally)
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
@ -102,7 +108,11 @@ class CPTritonContext:
def correct_attn_out(
out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
out: torch.Tensor,
lses: torch.Tensor,
cp_rank: int,
ctx: CPTritonContext,
is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Correct the attention output using the all-gathered lses.
@ -163,8 +173,7 @@ def correct_attn_out(
l_sH,
cp_rank,
)
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
return out, lse
@ -174,6 +183,7 @@ def _cp_lse_common(
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
@ -193,7 +203,13 @@ def _cp_lse_common(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out, lse = correct_attn_out(
cp_attn_out,
lses,
cp_group.rank_in_group,
ctx,
is_lse_base_on_e=is_lse_base_on_e,
)
return out, lse
@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out, lse = _cp_lse_common(
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
)
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e=True,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out, lse = _cp_lse_common(
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
)
out = cp_group.all_reduce(out)
if return_lse:

View File

@ -249,7 +249,11 @@ class BatchDCPPrefillWrapper:
return_lse=True,
)
output_context, lse_context = cp_lse_ag_out_rs(
output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
output_context_tmp,
lse_context_tmp,
get_dcp_group(),
return_lse=True,
is_lse_base_on_e=False,
)
lse_context = lse_context.transpose(0, 1).contiguous()
@ -1335,7 +1339,10 @@ class FlashInferImpl(AttentionImpl):
return_lse=True,
)
output[:num_decode_tokens] = cp_lse_ag_out_rs(
output_tmp, lse, get_dcp_group()
output_tmp,
lse,
get_dcp_group(),
is_lse_base_on_e=False,
)
else:
decode_wrapper.run(

View File

@ -2057,7 +2057,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# correct dcp attn_out with lse.
if self.dcp_world_size > 1:
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not self._use_fi_prefill,
)
# v_up projection
self._v_up_proj(attn_out, out=output[:num_decode_tokens])