vllm/vllm/attention/ops/common.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton


@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
cp, batch, q_heads, v_head_dim
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)

    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse: a numerically stable log-sum-exp over all ranks.
    # NaN/+inf entries are mapped to -inf so they contribute zero weight.
    lse = tl.load(lses_ptr + lse_offsets)
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
        head_idx * outputs_stride_H + \
        d_offsets * outputs_stride_D

    # correct output: rescale this rank's partial output by
    # exp(local_lse - global_lse).
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor
    tl.store(new_output_ptr + output_offsets, output)
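

# Illustrative sketch only (not used by the kernel above): roughly the same
# correction expressed in plain PyTorch, assuming `out` is [B, H, D], `lses`
# is [N, B, H] and `rank` indexes this rank's slice of `lses`. Names here are
# hypothetical.
#
#   lses_clean = torch.where(lses.isnan() | lses.isinf(),
#                            torch.full_like(lses, float('-inf')), lses)
#   global_lse = torch.logsumexp(lses_clean, dim=0)        # [B, H]
#   factor = torch.exp(lses[rank] - global_lse)            # [B, H]
#   corrected = out * factor.unsqueeze(-1)                 # [B, H, D]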


class CPTritonContext:
    """ The CPTritonContext is used to avoid recompilation of the Triton JIT.
    """

    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        # The first call compiles and launches the kernel; later calls reuse
        # the compiled kernel handle and skip the constexpr arguments.
        if self.inner_kernel is None:
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
        else:
            self.inner_kernel[grid](*regular_args)
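

# Usage note: callers that apply the correction repeatedly (e.g. once per
# decode step) can hold on to a single CPTritonContext so the Triton kernel
# is compiled once and only relaunched afterwards. A minimal, hypothetical
# sketch:
#
#   ctx = CPTritonContext()
#   for step_out, step_lses in steps:        # `steps` is hypothetical
#       out, lse = correct_attn_out(step_out, step_lses, cp_rank, ctx)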


def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
                     ctx: CPTritonContext):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.

    Args:
        out : [ B, H, D ]
        lses: [ N, B, H ]
    Return:
        output: [ B, H, D ]
        lse   : [ B, H ]
    """
    if ctx is None:
        ctx = CPTritonContext()

    lse = torch.empty_like(lses[0])
    grid = (out.shape[0], out.shape[1], 1)
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse
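

# Note: `correct_attn_out` rescales `out` in place (the kernel's input and
# output pointers refer to the same tensor) and returns the combined lse.
# A small, hypothetical shape check:
#
#   B, H, D, N = 4, 8, 128, 2                  # hypothetical sizes
#   out = torch.randn(B, H, D, device="cuda")
#   lses = torch.randn(N, B, H, device="cuda")
#   out, lse = correct_attn_out(out, lses, cp_rank=0, ctx=None)
#   assert out.shape == (B, H, D) and lse.shape == (B, H)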


def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """
    All-gather the per-rank lses, correct the local attention output, then
    reduce-scatter the corrected outputs across the context-parallel group.

    cp_attn_out: [ B, H, D ]
    cp_attn_lse: [ B, H ]
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
                       dtype=cp_attn_lse.dtype,
                       device=cp_attn_lse.device)

    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)
    return out
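

# End-to-end sketch (hypothetical, assuming a context-parallel
# GroupCoordinator named `cp_group` is already initialized): each rank holds
# a partial attention output over its slice of the KV cache plus the matching
# lse; this helper all-gathers the lses, rescales the local partial output,
# and reduce-scatters the result over the head dimension.
#
#   partial_out = ...   # [B, H, D] partial attention output on this rank
#   partial_lse = ...   # [B, H] log-sum-exp of this rank's attention scores
#   out = cp_lse_ag_out_rs(partial_out, partial_lse, cp_group)
#   # `out` now holds this rank's H // cp_group.world_size heads of the
#   # fully reduced attention output.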