diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index 097fbae68cda..1234e1b2e46a 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -117,14 +117,52 @@ def correct_attn_out(
     if ctx is None:
         ctx = CPTritonContext()
 
-    lse = torch.empty_like(lses[0])
+    # --- Normalize to 3D views ---
+    if out.ndim == 4 and out.shape[1] == 1:
+        out = out.squeeze(1)
+    assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
 
-    grid = (out.shape[0], out.shape[1], 1)
-    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank)
-    const_args = {
-        "HEAD_DIM": out.shape[-1],
-        "N_ROUNDED": lses.shape[0],
-    }
+    if lses.ndim == 4 and lses.shape[-1] == 1:
+        lses = lses.squeeze(-1)
+    if lses.ndim == 4 and lses.shape[1] == 1:
+        lses = lses.squeeze(1)
+    assert lses.ndim == 3, (
+        f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
+        f"got {tuple(lses.shape)}"
+    )
+
+    B, H, D = out.shape
+    N = lses.shape[0]
+
+    # Strides after we normalized shapes to 3-D views. The kernel computes
+    # offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
+    # have the same B/H stride layout as a slice of `lses`.
+    o_sB, o_sH, o_sD = out.stride()
+    l_sN, l_sB, l_sH = lses.stride()
+
+    # Allocate LSE with the same B/H strides as `lses` so writes land correctly
+    # even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
+    lse = torch.empty_strided(
+        (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
+    )
+
+    # Kernel launch config
+    grid = (B, H, 1)
+
+    regular_args = (
+        out,
+        out,
+        lses,
+        lse,
+        o_sB,
+        o_sH,
+        o_sD,
+        l_sN,
+        l_sB,
+        l_sH,
+        cp_rank,
+    )
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
 
     ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
     return out, lse
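
For reference, here is a minimal standalone sketch (not part of the patch; the shapes and tensor names below are made up for illustration) of the stride-matching idea behind the torch.empty_strided allocation: when lses is a non-contiguous [N, B, H] view, the kernel's b*stride_B + h*stride_H addressing only lands on the right elements of the output LSE buffer if that buffer shares the same B/H strides.

# Illustrative sketch, hypothetical shapes -- not the patch's code path.
import torch

N, B, H = 4, 2, 8

# A non-contiguous [N, B, H] view, e.g. produced by squeezing/permuting a
# 4-D buffer: its B/H strides are (1, B), not the contiguous (H, 1).
lses = torch.randn(N, H, B).permute(0, 2, 1)  # shape [N, B, H]
l_sN, l_sB, l_sH = lses.stride()
print((l_sN, l_sB, l_sH))  # (16, 1, 2) here, not the contiguous (16, 8, 1)

# The kernel addresses the output LSE buffer with the same b*l_sB + h*l_sH
# arithmetic it uses for lses, so the buffer must mirror that layout.
lse = torch.empty_strided(
    (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
assert lse.stride() == (l_sB, l_sH)

# A plain contiguous allocation would have strides (H, 1) instead, and the
# kernel's writes would land at the wrong offsets for this layout.
assert torch.empty(B, H).stride() == (H, 1)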