From 0cd103e7cbf0315c69434870c4973ded2c5d99e5 Mon Sep 17 00:00:00 2001 From: Huamin Li <3ericli@gmail.com> Date: Sat, 11 Oct 2025 13:50:57 -0700 Subject: [PATCH] =?UTF-8?q?CP:=20make=20correct=5Fattn=5Fout=20robust=20to?= =?UTF-8?q?=204=E2=80=91D=20views=20and=20fix=20Triton=20arg=20binding=20(?= =?UTF-8?q?#26509)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Huamin Li <3ericli@gmail.com> --- vllm/attention/ops/common.py | 52 +++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 097fbae68cda..1234e1b2e46a 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -117,14 +117,52 @@ def correct_attn_out( if ctx is None: ctx = CPTritonContext() - lse = torch.empty_like(lses[0]) + # --- Normalize to 3D views --- + if out.ndim == 4 and out.shape[1] == 1: + out = out.squeeze(1) + assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}" - grid = (out.shape[0], out.shape[1], 1) - regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank) - const_args = { - "HEAD_DIM": out.shape[-1], - "N_ROUNDED": lses.shape[0], - } + if lses.ndim == 4 and lses.shape[-1] == 1: + lses = lses.squeeze(-1) + if lses.ndim == 4 and lses.shape[1] == 1: + lses = lses.squeeze(1) + assert lses.ndim == 3, ( + f"expected lses [N,B,H] (optionally with a 1-sized extra dim), " + f"got {tuple(lses.shape)}" + ) + + B, H, D = out.shape + N = lses.shape[0] + + # Strides after we normalized shapes to 3-D views. The kernel computes + # offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must + # have the same B/H stride layout as a slice of `lses`. + o_sB, o_sH, o_sD = out.stride() + l_sN, l_sB, l_sH = lses.stride() + + # Allocate LSE with the same B/H strides as `lses` so writes land correctly + # even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze). + lse = torch.empty_strided( + (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype + ) + + # Kernel launch config + grid = (B, H, 1) + + regular_args = ( + out, + out, + lses, + lse, + o_sB, + o_sH, + o_sD, + l_sN, + l_sB, + l_sH, + cp_rank, + ) + const_args = {"HEAD_DIM": D, "N_ROUNDED": N} ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) return out, lse