mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 00:24:30 +08:00
[Perf] Remove sync point in vit torch sdpa attn backend (#30232)
Signed-off-by: Dazhi Jiang <dazhi_jiang@163.com>
This commit is contained in:
parent
cd00c443d2
commit
bcb6f5947f
@ -93,12 +93,12 @@ def torch_sdpa_wrapper(
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_chunks = torch.split(q, lens, dim=1)
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
q_i, k_i, v_i = (
|
||||
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
)
|
||||
|
||||
@ -289,12 +289,12 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_chunks = torch.split(q, lens, dim=1)
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
)
|
||||
|
||||
@ -377,12 +377,12 @@ class Glm4vVisionAttention(nn.Module):
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_chunks = torch.split(q, lens, dim=1)
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
)
|
||||
|
||||
@ -424,12 +424,12 @@ class Qwen2VisionAttention(nn.Module):
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
end_idx = cu_seqlens[i]
|
||||
q_i = q[:, start_idx:end_idx]
|
||||
k_i = k[:, start_idx:end_idx]
|
||||
v_i = v[:, start_idx:end_idx]
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_chunks = torch.split(q, lens, dim=1)
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
q_i, k_i, v_i = (
|
||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user