bugfix: correct attn output with base 2 or e (#28840)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>

parent 3fd1fb0b60
commit 9726e64530
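The decode-context-parallel (DCP) LSE-correction path previously assumed every attention backend reports its log-sum-exp in base e: the Triton kernel merged the all-gathered LSEs with tl.exp/tl.log unconditionally. FlashInfer returns its LSE in base 2, so merging those values with base-e arithmetic yields wrong rescale factors and therefore wrong attention output. This change threads an is_lse_base_on_e flag (default True, so existing callers are unchanged) from the public wrappers down to the kernel, which now branches on a tl.constexpr to use exp2/log2 for base-2 LSEs; the FlashInfer-backed call sites opt in with is_lse_base_on_e=False.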
@@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
     lse_idx,
     HEAD_DIM: tl.constexpr,
     N_ROUNDED: tl.constexpr,
+    IS_BASE_E: tl.constexpr,
 ):
     """
     Apply the all-gathered lses to correct each local rank's attention
@@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
     lse_max = tl.max(lse, axis=0)
     lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
     lse -= lse_max
-    lse_exp = tl.exp(lse)
-    lse_acc = tl.sum(lse_exp, axis=0)
-    lse = tl.log(lse_acc)
+    if IS_BASE_E:
+        lse_exp = tl.exp(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log(lse_acc)
+    else:
+        lse_exp = tl.exp2(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log2(lse_acc)
     lse += lse_max

     lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
@@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
         -float("inf"),
         lse_finally,
     )
-    factor = tl.exp(lse_finally)
+    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
     output = tl.load(outputs_ptr + output_offsets)
     output = output * factor

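Why the base matters: each rank i holds a partial output out_i and an LSE in some base b (base e: lse_e = log(sum_j exp(s_j)); base 2: lse_2 = lse_e / ln 2). The merged output is sum_i b**(lse_i - lse_total) * out_i with lse_total = log_b(sum_i b**lse_i), so the merge must use exp/log in the same base the LSE was written in; the old kernel always used base e. A minimal self-contained NumPy sketch of the merge rule (illustrative only, not vLLM code; all names invented):

import numpy as np

def merge(outs, lses, base_e=True):
    # merge per-rank partial outputs using LSEs expressed in base e or base 2
    exp_, log_ = (np.exp, np.log) if base_e else (np.exp2, np.log2)
    m = lses.max(axis=0)                        # max-subtraction for stability
    total = log_(exp_(lses - m).sum(axis=0)) + m
    factors = exp_(lses - total)                # per-rank rescale factors
    return (factors[:, None] * outs).sum(axis=0)

rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 8))                # two ranks, 8 keys each
values = rng.normal(size=(2, 8))
lse_e = np.log(np.exp(scores).sum(axis=1))      # per-rank base-e LSE
# per-rank softmax-normalized partial outputs
outs = (np.exp(scores - lse_e[:, None]) * values).sum(axis=1, keepdims=True)
ref = merge(outs, lse_e, base_e=True)           # base-e LSE + base-e merge
lse_2 = lse_e / np.log(2)                       # same LSE expressed in base 2
print(np.allclose(ref, merge(outs, lse_2, base_e=False)))  # True: exp2/log2
print(np.allclose(ref, merge(outs, lse_2, base_e=True)))   # False: the old bug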
@@ -102,7 +108,11 @@ class CPTritonContext:


 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.

@@ -163,8 +173,7 @@ def correct_attn_out(
         l_sH,
         cp_rank,
     )
-    const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
     ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
     return out, lse

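IS_BASE_E is passed as a tl.constexpr, so Triton specializes the compiled kernel per flag value and the if/else in the kernel is resolved at compile time rather than per element. A standalone sketch of that mechanism (assumed example, not part of this change; requires triton and a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def _exp_kernel(x_ptr, y_ptr, N: tl.constexpr, IS_BASE_E: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    if IS_BASE_E:           # constexpr: dead branch is eliminated at compile time
        y = tl.exp(x)
    else:
        y = tl.exp2(x)
    tl.store(y_ptr + offs, y)

x = torch.randn(16, device="cuda")
y = torch.empty_like(x)
_exp_kernel[(1,)](x, y, N=16, IS_BASE_E=False)
assert torch.allclose(y, torch.exp2(x))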
@@ -174,6 +183,7 @@ def _cp_lse_common(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -193,7 +203,13 @@ def _cp_lse_common(

     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(
+        cp_attn_out,
+        lses,
+        cp_group.rank_in_group,
+        ctx,
+        is_lse_base_on_e=is_lse_base_on_e,
+    )
     return out, lse

@@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
@@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext | None = None,
     return_lse: bool = False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out, lse = _cp_lse_common(
+        cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
+    )
     out = cp_group.all_reduce(out)

     if return_lse:
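Both wrappers delegate to _cp_lse_common (all-gather the LSEs across the group, then run the correction kernel) and differ only in how the corrected output is combined: cp_lse_ag_out_rs reduce-scatters along the head dimension, while cp_lse_ag_out_ar all-reduces the full output. Because is_lse_base_on_e defaults to True, callers that already pass base-e LSEs need no change; only the FlashInfer-backed call sites below opt in.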
@@ -249,7 +249,11 @@ class BatchDCPPrefillWrapper:
             return_lse=True,
         )
         output_context, lse_context = cp_lse_ag_out_rs(
-            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+            output_context_tmp,
+            lse_context_tmp,
+            get_dcp_group(),
+            return_lse=True,
+            is_lse_base_on_e=False,
         )
         lse_context = lse_context.transpose(0, 1).contiguous()

@@ -1335,7 +1339,10 @@ class FlashInferImpl(AttentionImpl):
                 return_lse=True,
             )
             output[:num_decode_tokens] = cp_lse_ag_out_rs(
-                output_tmp, lse, get_dcp_group()
+                output_tmp,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=False,
             )
         else:
             decode_wrapper.run(
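Both FlashInfer call sites above (the DCP prefill wrapper and the decode path) now pass is_lse_base_on_e=False, matching FlashInfer's convention of returning the log-sum-exp in base 2.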
@@ -2057,7 +2057,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

         # correct dcp attn_out with lse.
         if self.dcp_world_size > 1:
-            attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+            attn_out = cp_lse_ag_out_rs(
+                attn_out,
+                lse,
+                get_dcp_group(),
+                is_lse_base_on_e=not self._use_fi_prefill,
+            )

         # v_up projection
         self._v_up_proj(attn_out, out=output[:num_decode_tokens])
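In the MLA path the base depends on which backend produced the LSE: with FlashInfer in use (self._use_fi_prefill) the LSE is base 2, otherwise base e, hence is_lse_base_on_e=not self._use_fi_prefill.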