mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 17:47:06 +08:00
[Perf] Remove sync point in vit torch sdpa attn backend (#30232)
Signed-off-by: Dazhi Jiang <dazhi_jiang@163.com>
This commit is contained in:
parent
cd00c443d2
commit
bcb6f5947f
@ -93,12 +93,12 @@ def torch_sdpa_wrapper(
|
|||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in range(1, len(cu_seqlens)):
|
|
||||||
start_idx = cu_seqlens[i - 1]
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
end_idx = cu_seqlens[i]
|
q_chunks = torch.split(q, lens, dim=1)
|
||||||
q_i = q[:, start_idx:end_idx]
|
k_chunks = torch.split(k, lens, dim=1)
|
||||||
k_i = k[:, start_idx:end_idx]
|
v_chunks = torch.split(v, lens, dim=1)
|
||||||
v_i = v[:, start_idx:end_idx]
|
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||||
q_i, k_i, v_i = (
|
q_i, k_i, v_i = (
|
||||||
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -289,12 +289,12 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||||
# Execute attention entry by entry for speed & less VRAM.
|
# Execute attention entry by entry for speed & less VRAM.
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in range(1, len(cu_seqlens)):
|
|
||||||
start_idx = cu_seqlens[i - 1]
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
end_idx = cu_seqlens[i]
|
q_chunks = torch.split(q, lens, dim=1)
|
||||||
q_i = q[:, start_idx:end_idx]
|
k_chunks = torch.split(k, lens, dim=1)
|
||||||
k_i = k[:, start_idx:end_idx]
|
v_chunks = torch.split(v, lens, dim=1)
|
||||||
v_i = v[:, start_idx:end_idx]
|
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||||
q_i, k_i, v_i = (
|
q_i, k_i, v_i = (
|
||||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -377,12 +377,12 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||||
# Execute attention entry by entry for speed & less VRAM.
|
# Execute attention entry by entry for speed & less VRAM.
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in range(1, len(cu_seqlens)):
|
|
||||||
start_idx = cu_seqlens[i - 1]
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
end_idx = cu_seqlens[i]
|
q_chunks = torch.split(q, lens, dim=1)
|
||||||
q_i = q[:, start_idx:end_idx]
|
k_chunks = torch.split(k, lens, dim=1)
|
||||||
k_i = k[:, start_idx:end_idx]
|
v_chunks = torch.split(v, lens, dim=1)
|
||||||
v_i = v[:, start_idx:end_idx]
|
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||||
q_i, k_i, v_i = (
|
q_i, k_i, v_i = (
|
||||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -424,12 +424,12 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
outputs = []
|
outputs = []
|
||||||
for i in range(1, len(cu_seqlens)):
|
|
||||||
start_idx = cu_seqlens[i - 1]
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
end_idx = cu_seqlens[i]
|
q_chunks = torch.split(q, lens, dim=1)
|
||||||
q_i = q[:, start_idx:end_idx]
|
k_chunks = torch.split(k, lens, dim=1)
|
||||||
k_i = k[:, start_idx:end_idx]
|
v_chunks = torch.split(v, lens, dim=1)
|
||||||
v_i = v[:, start_idx:end_idx]
|
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||||
q_i, k_i, v_i = (
|
q_i, k_i, v_i = (
|
||||||
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user