mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 04:15:01 +08:00
CP: make correct_attn_out robust to 4‑D views and fix Triton arg binding (#26509)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
parent
5be7ca1b99
commit
0cd103e7cb
@ -117,14 +117,52 @@ def correct_attn_out(
|
|||||||
if ctx is None:
|
if ctx is None:
|
||||||
ctx = CPTritonContext()
|
ctx = CPTritonContext()
|
||||||
|
|
||||||
lse = torch.empty_like(lses[0])
|
# --- Normalize to 3D views ---
|
||||||
|
if out.ndim == 4 and out.shape[1] == 1:
|
||||||
|
out = out.squeeze(1)
|
||||||
|
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
|
||||||
|
|
||||||
grid = (out.shape[0], out.shape[1], 1)
|
if lses.ndim == 4 and lses.shape[-1] == 1:
|
||||||
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank)
|
lses = lses.squeeze(-1)
|
||||||
const_args = {
|
if lses.ndim == 4 and lses.shape[1] == 1:
|
||||||
"HEAD_DIM": out.shape[-1],
|
lses = lses.squeeze(1)
|
||||||
"N_ROUNDED": lses.shape[0],
|
assert lses.ndim == 3, (
|
||||||
}
|
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
|
||||||
|
f"got {tuple(lses.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
B, H, D = out.shape
|
||||||
|
N = lses.shape[0]
|
||||||
|
|
||||||
|
# Strides after we normalized shapes to 3-D views. The kernel computes
|
||||||
|
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
|
||||||
|
# have the same B/H stride layout as a slice of `lses`.
|
||||||
|
o_sB, o_sH, o_sD = out.stride()
|
||||||
|
l_sN, l_sB, l_sH = lses.stride()
|
||||||
|
|
||||||
|
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
|
||||||
|
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
|
||||||
|
lse = torch.empty_strided(
|
||||||
|
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Kernel launch config
|
||||||
|
grid = (B, H, 1)
|
||||||
|
|
||||||
|
regular_args = (
|
||||||
|
out,
|
||||||
|
out,
|
||||||
|
lses,
|
||||||
|
lse,
|
||||||
|
o_sB,
|
||||||
|
o_sH,
|
||||||
|
o_sD,
|
||||||
|
l_sN,
|
||||||
|
l_sB,
|
||||||
|
l_sH,
|
||||||
|
cp_rank,
|
||||||
|
)
|
||||||
|
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
|
||||||
|
|
||||||
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
|
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user