From bcb6f5947f8acac22f5d1d5fc92a91b06fd57e77 Mon Sep 17 00:00:00 2001
From: Dazhi Jiang
Date: Mon, 8 Dec 2025 15:12:42 +0800
Subject: [PATCH] [Perf] Remove sync point in vit torch sdpa attn backend (#30232)

Signed-off-by: Dazhi Jiang
---
 vllm/attention/ops/vit_attn_wrappers.py  | 12 ++++++------
 vllm/model_executor/models/ernie45_vl.py | 12 ++++++------
 vllm/model_executor/models/glm4_1v.py    | 12 ++++++------
 vllm/model_executor/models/qwen2_vl.py   | 12 ++++++------
 4 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index d9f15f1e42858..9036c2b801949 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -93,12 +93,12 @@ def torch_sdpa_wrapper(
     cu_seqlens: torch.Tensor,
 ) -> torch.Tensor:
     outputs = []
-    for i in range(1, len(cu_seqlens)):
-        start_idx = cu_seqlens[i - 1]
-        end_idx = cu_seqlens[i]
-        q_i = q[:, start_idx:end_idx]
-        k_i = k[:, start_idx:end_idx]
-        v_i = v[:, start_idx:end_idx]
+
+    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    q_chunks = torch.split(q, lens, dim=1)
+    k_chunks = torch.split(k, lens, dim=1)
+    v_chunks = torch.split(v, lens, dim=1)
+    for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
         q_i, k_i, v_i = (
             einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
         )
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 3305b6a0e58f0..053d260cc09b2 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -289,12 +289,12 @@ class Ernie4_5_VisionAttention(nn.Module):
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 39a837b789bed..741edfdda3e2c 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -377,12 +377,12 @@ class Glm4vVisionAttention(nn.Module):
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 885e172d1e81a..608e90337f452 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -424,12 +424,12 @@ class Qwen2VisionAttention(nn.Module):
             k = k.contiguous()
             v = v.contiguous()
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
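
The pattern applied in all four files is the same: the old loop sliced q/k/v with 0-dim GPU tensors taken from `cu_seqlens`, and converting each of those to a slice bound forces a device-to-host synchronization on every iteration. The new code pays that cost once, with a single `.tolist()`, and then splits the packed q/k/v tensors on host-side lengths via `torch.split`. Below is a minimal self-contained sketch of the idea; the function and argument names are illustrative only, not vLLM APIs.

import torch
import torch.nn.functional as F

def sdpa_per_sequence(q, k, v, cu_seqlens):
    # q, k, v: [batch, total_seq, heads, head_dim]; cu_seqlens: [num_seqs + 1].
    # One device-to-host transfer for all lengths, instead of one sync per
    # loop iteration when slicing with 0-dim GPU tensors.
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    outputs = []
    for q_i, k_i, v_i in zip(
        torch.split(q, lens, dim=1),
        torch.split(k, lens, dim=1),
        torch.split(v, lens, dim=1),
    ):
        # [b, s, h, d] -> [b, h, s, d], the layout scaled_dot_product_attention expects.
        q_i, k_i, v_i = (x.transpose(1, 2) for x in (q_i, k_i, v_i))
        out = F.scaled_dot_product_attention(q_i, k_i, v_i)
        outputs.append(out.transpose(1, 2))
    # Re-pack the per-sequence outputs along the sequence dimension.
    return torch.cat(outputs, dim=1)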