# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton


@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
"""
|
|
Apply the all-gathered lses to correct each local rank's attention
|
|
output. we still need perform a cross-rank reduction to obtain the
|
|
final attention output.
|
|
|
|
Args:
|
|
output: [ B, H, D ]
|
|
lses : [ N, B, H ]
|
|
cp, batch, q_heads, v_head_dim
|
|
Return:
|
|
output: [ B, H, D ]
|
|
lse : [ B, H ]
|
|
"""
|
|
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # Offsets of this (batch, head) position in every rank's LSE,
    # shape = [N].
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # Compute the final (merged) LSE with a numerically stable
    # log-sum-exp over ranks: mask NaN/+inf entries to -inf, subtract
    # the max, exponentiate, sum, then shift back.
    lse = tl.load(lses_ptr + lse_offsets)
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    # Store the merged LSE for this (batch, head) position.
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # Offsets of this (batch, head) row in the output, shape = [D].
    output_offsets = batch_idx * outputs_stride_B + \
        head_idx * outputs_stride_H + \
        d_offsets * outputs_stride_D

    # Correct the output: rescale this rank's partial result by
    # exp(local_lse - merged_lse), the blockwise-softmax
    # renormalization factor.
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)


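# What follows is a minimal pure-PyTorch reference of the correction the
# kernel above performs, added here for illustration only (it is not part
# of the vLLM API). It assumes the same layouts as the kernel: outputs of
# shape [ B, H, D ] and all-gathered lses of shape [ N, B, H ].
def _correct_attn_cp_out_ref(outputs: torch.Tensor, lses: torch.Tensor,
                             lse_idx: int):
    # Mask NaN/+inf LSE entries to -inf, as the kernel does.
    lses = torch.where(torch.isnan(lses) | (lses == float('inf')),
                       torch.full_like(lses, float('-inf')), lses)
    # Merged LSE across ranks, shape [ B, H ].
    lse = torch.logsumexp(lses, dim=0)
    # Renormalization factor exp(local_lse - merged_lse), shape [ B, H ].
    diff = lses[lse_idx] - lse
    diff = torch.where(torch.isnan(diff) | (diff == float('inf')),
                       torch.full_like(diff, float('-inf')), diff)
    return outputs * torch.exp(diff).unsqueeze(-1), lse

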
class CPTritonContext:
    """The CPTritonContext is used to avoid recompilation of the Triton
    JIT across repeated kernel launches.
    """

    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        if self.inner_kernel is None:
            # First launch: JIT-compile the kernel and cache the handle.
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
        else:
            # Later launches reuse the cached kernel; the constexpr
            # arguments are already baked into the compiled binary.
            self.inner_kernel[grid](*regular_args)


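# Typical usage (an illustrative sketch, not a pattern mandated by vLLM):
# keep one context alive per caller and pass it to every launch, so the
# Triton kernel is compiled once and reused afterwards.
#
#   _CP_CTX = CPTritonContext()
#   out, lse = correct_attn_out(out, lses, cp_rank, _CP_CTX)  # compiles
#   out, lse = correct_attn_out(out, lses, cp_rank, _CP_CTX)  # cached

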
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
                     ctx: CPTritonContext):
    """
    Apply the all-gathered LSEs to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.

    Args:
        out:  [ B, H, D ]
        lses: [ N, B, H ]
    Return:
        output: [ B, H, D ]
        lse:    [ B, H ]
    """
    if ctx is None:
        ctx = CPTritonContext()

    # Buffer for the merged LSE computed by the kernel, shape [ B, H ].
    lse = torch.empty_like(lses[0])

    # One program instance per (batch, head); the correction is written
    # back into `out` in place (outputs_ptr and new_output_ptr alias).
    grid = (out.shape[0], out.shape[1], 1)
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse


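# Why the correction is sufficient (illustrative, CPU-only check against
# the reference helper above; the shapes and names here are made up for
# the example): per-rank softmax outputs rescaled by
# exp(lse_r - merged_lse) and summed over ranks equal attention over the
# concatenated KV.
#
#   B, H, D, N, T = 2, 3, 8, 4, 16          # T = keys held per rank
#   q = torch.randn(B, H, 1, D)
#   k, v = torch.randn(N, B, H, T, D), torch.randn(N, B, H, T, D)
#   scores = [q @ k[r].transpose(-1, -2) for r in range(N)]  # [B, H, 1, T]
#   lses = torch.stack([s.logsumexp(-1).squeeze(-1) for s in scores])
#   outs = [(s.softmax(-1) @ v[r]).squeeze(2) for r, s in enumerate(scores)]
#   merged = sum(_correct_attn_cp_out_ref(outs[r], lses, r)[0]
#                for r in range(N))
#   full = (torch.cat(scores, -1).softmax(-1)
#           @ torch.cat(list(v), dim=-2)).squeeze(2)
#   torch.testing.assert_close(merged, full)

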
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """
    All-gather the per-rank LSEs, correct the local attention output,
    and reduce-scatter it across the context-parallel group.

    Args:
        cp_attn_out: [ B, H, D ]
        cp_attn_lse: [ B, H ]
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
                       dtype=cp_attn_lse.dtype,
                       device=cp_attn_lse.device)

    # Gather every rank's LSE so each rank can renormalize its partial
    # attention output against the global softmax statistics.
    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    # Sum the corrected partial outputs across ranks while scattering
    # along the head dimension, so each rank keeps 1/N of the heads.
    out = cp_group.reduce_scatter(out, dim=1)
    return out


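# Illustrative call pattern (the surrounding setup is assumed, not vLLM
# API): each rank runs attention over its local KV shard, obtaining a
# partial output [ B, H, D ] and its LSE [ B, H ], then merges them here.
#
#   ctx = CPTritonContext()
#   # local_out: [ B, H, D ], local_lse: [ B, H ] from the local attention
#   final_out = cp_lse_ag_out_rs(local_out, local_lse, cp_group, ctx=ctx)
#   # final_out: [ B, H // cp_group.world_size, D ] after the
#   # reduce-scatter over the head dimension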