[Perf] Remove sync point in vit torch sdpa attn backend (#30232)

Signed-off-by: Dazhi Jiang <dazhi_jiang@163.com>
Authored by Dazhi Jiang on 2025-12-08 15:12:42 +08:00, committed by GitHub
parent cd00c443d2
commit bcb6f5947f
4 changed files with 24 additions and 24 deletions
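
The change is identical in all four files: the old per-sequence loop indexes `cu_seqlens` (a CUDA tensor) on every iteration, and each element read forces a blocking device-to-host transfer, i.e. one sync point per sequence. The rewrite pays that cost once: a single `.tolist()` transfers all sequence lengths, and `torch.split` then produces views of `q`, `k`, and `v` without touching the device again. Below is a minimal standalone sketch of the two patterns; the tensor shapes and variable names are illustrative, not the vLLM code itself.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Packed ViT layout assumed here: (batch=1, total_tokens, heads, head_dim).
    q = torch.randn(1, 10, 4, 8, device=device)
    cu_seqlens = torch.tensor([0, 3, 7, 10], device=device)

    # Before: each cu_seqlens[i] read is an implicit .item(),
    # so on GPU this syncs once per sequence.
    chunks_old = []
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        chunks_old.append(q[:, start:end])

    # After: one .tolist() = one sync; torch.split then yields views.
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    chunks_new = torch.split(q, lens, dim=1)

    # Both patterns produce the same chunks.
    assert all(torch.equal(a, b) for a, b in zip(chunks_old, chunks_new))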


@@ -93,12 +93,12 @@ def torch_sdpa_wrapper(
     cu_seqlens: torch.Tensor,
 ) -> torch.Tensor:
     outputs = []
-    for i in range(1, len(cu_seqlens)):
-        start_idx = cu_seqlens[i - 1]
-        end_idx = cu_seqlens[i]
-        q_i = q[:, start_idx:end_idx]
-        k_i = k[:, start_idx:end_idx]
-        v_i = v[:, start_idx:end_idx]
+    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    q_chunks = torch.split(q, lens, dim=1)
+    k_chunks = torch.split(k, lens, dim=1)
+    v_chunks = torch.split(v, lens, dim=1)
+    for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
         q_i, k_i, v_i = (
             einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
         )


@@ -289,12 +289,12 @@ class Ernie4_5_VisionAttention(nn.Module):
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )


@@ -377,12 +377,12 @@ class Glm4vVisionAttention(nn.Module):
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )


@@ -424,12 +424,12 @@ class Qwen2VisionAttention(nn.Module):
             k = k.contiguous()
             v = v.contiguous()
             outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
+            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            q_chunks = torch.split(q, lens, dim=1)
+            k_chunks = torch.split(k, lens, dim=1)
+            v_chunks = torch.split(v, lens, dim=1)
+            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                 q_i, k_i, v_i = (
                     rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                 )
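
To double-check that the new path really performs a single transfer, PyTorch's CUDA sync debug mode can surface implicit host-device syncs. This is an illustrative check under the assumption of a CUDA device, not part of the commit:

    import torch

    if torch.cuda.is_available():
        cu_seqlens = torch.tensor([0, 3, 7, 10], device="cuda")
        torch.cuda.set_sync_debug_mode("warn")  # report every implicit sync
        _ = cu_seqlens[1].item()  # old pattern: one warning per element read
        _ = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # new pattern: one warning total
        torch.cuda.set_sync_debug_mode("default")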