mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 04:29:08 +08:00
bugfix: correct attn output with base 2 or e (#28840)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
parent
3fd1fb0b60
commit
9726e64530
@ -21,6 +21,7 @@ def _correct_attn_cp_out_kernel(
|
|||||||
lse_idx,
|
lse_idx,
|
||||||
HEAD_DIM: tl.constexpr,
|
HEAD_DIM: tl.constexpr,
|
||||||
N_ROUNDED: tl.constexpr,
|
N_ROUNDED: tl.constexpr,
|
||||||
|
IS_BASE_E: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Apply the all-gathered lses to correct each local rank's attention
|
Apply the all-gathered lses to correct each local rank's attention
|
||||||
@ -55,9 +56,14 @@ def _correct_attn_cp_out_kernel(
|
|||||||
lse_max = tl.max(lse, axis=0)
|
lse_max = tl.max(lse, axis=0)
|
||||||
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
|
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
|
||||||
lse -= lse_max
|
lse -= lse_max
|
||||||
lse_exp = tl.exp(lse)
|
if IS_BASE_E:
|
||||||
lse_acc = tl.sum(lse_exp, axis=0)
|
lse_exp = tl.exp(lse)
|
||||||
lse = tl.log(lse_acc)
|
lse_acc = tl.sum(lse_exp, axis=0)
|
||||||
|
lse = tl.log(lse_acc)
|
||||||
|
else:
|
||||||
|
lse_exp = tl.exp2(lse)
|
||||||
|
lse_acc = tl.sum(lse_exp, axis=0)
|
||||||
|
lse = tl.log2(lse_acc)
|
||||||
lse += lse_max
|
lse += lse_max
|
||||||
|
|
||||||
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
|
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
|
||||||
@ -81,7 +87,7 @@ def _correct_attn_cp_out_kernel(
|
|||||||
-float("inf"),
|
-float("inf"),
|
||||||
lse_finally,
|
lse_finally,
|
||||||
)
|
)
|
||||||
factor = tl.exp(lse_finally)
|
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
|
||||||
output = tl.load(outputs_ptr + output_offsets)
|
output = tl.load(outputs_ptr + output_offsets)
|
||||||
output = output * factor
|
output = output * factor
|
||||||
|
|
||||||
@ -102,7 +108,11 @@ class CPTritonContext:
|
|||||||
|
|
||||||
|
|
||||||
def correct_attn_out(
|
def correct_attn_out(
|
||||||
out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
|
out: torch.Tensor,
|
||||||
|
lses: torch.Tensor,
|
||||||
|
cp_rank: int,
|
||||||
|
ctx: CPTritonContext,
|
||||||
|
is_lse_base_on_e: bool = True,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Correct the attention output using the all-gathered lses.
|
"""Correct the attention output using the all-gathered lses.
|
||||||
|
|
||||||
@ -163,8 +173,7 @@ def correct_attn_out(
|
|||||||
l_sH,
|
l_sH,
|
||||||
cp_rank,
|
cp_rank,
|
||||||
)
|
)
|
||||||
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
|
const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
|
||||||
|
|
||||||
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
|
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
@ -174,6 +183,7 @@ def _cp_lse_common(
|
|||||||
cp_attn_lse: torch.Tensor,
|
cp_attn_lse: torch.Tensor,
|
||||||
cp_group: GroupCoordinator,
|
cp_group: GroupCoordinator,
|
||||||
ctx: CPTritonContext | None = None,
|
ctx: CPTritonContext | None = None,
|
||||||
|
is_lse_base_on_e=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
cp_attn_out: [ B, H, D ]
|
cp_attn_out: [ B, H, D ]
|
||||||
@ -193,7 +203,13 @@ def _cp_lse_common(
|
|||||||
|
|
||||||
cp_attn_lse = cp_attn_lse.contiguous()
|
cp_attn_lse = cp_attn_lse.contiguous()
|
||||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||||
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
out, lse = correct_attn_out(
|
||||||
|
cp_attn_out,
|
||||||
|
lses,
|
||||||
|
cp_group.rank_in_group,
|
||||||
|
ctx,
|
||||||
|
is_lse_base_on_e=is_lse_base_on_e,
|
||||||
|
)
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
|
|
||||||
@ -203,12 +219,15 @@ def cp_lse_ag_out_rs(
|
|||||||
cp_group: GroupCoordinator,
|
cp_group: GroupCoordinator,
|
||||||
ctx: CPTritonContext | None = None,
|
ctx: CPTritonContext | None = None,
|
||||||
return_lse: bool = False,
|
return_lse: bool = False,
|
||||||
|
is_lse_base_on_e=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
cp_attn_out: [ B, H, D ]
|
cp_attn_out: [ B, H, D ]
|
||||||
cp_attn_lse: [ B, H ]
|
cp_attn_lse: [ B, H ]
|
||||||
"""
|
"""
|
||||||
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
out, lse = _cp_lse_common(
|
||||||
|
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
|
||||||
|
)
|
||||||
out = cp_group.reduce_scatter(out, dim=1)
|
out = cp_group.reduce_scatter(out, dim=1)
|
||||||
|
|
||||||
if return_lse:
|
if return_lse:
|
||||||
@ -225,12 +244,15 @@ def cp_lse_ag_out_ar(
|
|||||||
cp_group: GroupCoordinator,
|
cp_group: GroupCoordinator,
|
||||||
ctx: CPTritonContext | None = None,
|
ctx: CPTritonContext | None = None,
|
||||||
return_lse: bool = False,
|
return_lse: bool = False,
|
||||||
|
is_lse_base_on_e=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
cp_attn_out: [ B, H, D ]
|
cp_attn_out: [ B, H, D ]
|
||||||
cp_attn_lse: [ B, H ]
|
cp_attn_lse: [ B, H ]
|
||||||
"""
|
"""
|
||||||
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
out, lse = _cp_lse_common(
|
||||||
|
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
|
||||||
|
)
|
||||||
out = cp_group.all_reduce(out)
|
out = cp_group.all_reduce(out)
|
||||||
|
|
||||||
if return_lse:
|
if return_lse:
|
||||||
|
|||||||
@ -249,7 +249,11 @@ class BatchDCPPrefillWrapper:
|
|||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
output_context, lse_context = cp_lse_ag_out_rs(
|
output_context, lse_context = cp_lse_ag_out_rs(
|
||||||
output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
|
output_context_tmp,
|
||||||
|
lse_context_tmp,
|
||||||
|
get_dcp_group(),
|
||||||
|
return_lse=True,
|
||||||
|
is_lse_base_on_e=False,
|
||||||
)
|
)
|
||||||
lse_context = lse_context.transpose(0, 1).contiguous()
|
lse_context = lse_context.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
@ -1335,7 +1339,10 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
output[:num_decode_tokens] = cp_lse_ag_out_rs(
|
output[:num_decode_tokens] = cp_lse_ag_out_rs(
|
||||||
output_tmp, lse, get_dcp_group()
|
output_tmp,
|
||||||
|
lse,
|
||||||
|
get_dcp_group(),
|
||||||
|
is_lse_base_on_e=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decode_wrapper.run(
|
decode_wrapper.run(
|
||||||
|
|||||||
@ -2057,7 +2057,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
|
|
||||||
# correct dcp attn_out with lse.
|
# correct dcp attn_out with lse.
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
attn_out = cp_lse_ag_out_rs(
|
||||||
|
attn_out,
|
||||||
|
lse,
|
||||||
|
get_dcp_group(),
|
||||||
|
is_lse_base_on_e=not self._use_fi_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
# v_up projection
|
# v_up projection
|
||||||
self._v_up_proj(attn_out, out=output[:num_decode_tokens])
|
self._v_up_proj(attn_out, out=output[:num_decode_tokens])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user