Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Models][Qwen] Replace pad with cat for better performance (#26486)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
commit 2c1c7dfb35
parent e246ad6f0c
@@ -680,7 +680,7 @@ class DotsVisionTransformer(nn.Module):
             dim=0,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
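The change is identical in each file: every vision tower prepends the leading zero that a cumulative-sequence-length (cu_seqlens) buffer needs before it reaches the varlen attention kernels. F.pad in constant mode routes through the generic constant_pad_nd op, while torch.cat with a preallocated zero performs the prepend directly; the commit's claim is that the latter is cheaper. A minimal sketch with made-up sequence lengths, checking that the two forms agree:

import torch
import torch.nn.functional as F

# Hypothetical cumulative patch counts for three images.
cu_seqlens = torch.tensor([64, 192, 256], dtype=torch.int32)

# Old form: constant padding of one element on the left.
padded = F.pad(cu_seqlens, (1, 0), value=0)

# New form: concatenate a single zero in front. new_zeros(1) keeps the
# dtype and device of cu_seqlens, so no extra conversion is needed.
catted = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

assert torch.equal(padded, catted)  # both give [0, 64, 192, 256]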
@@ -574,11 +574,12 @@ class Ernie4_5_VisionTransformer(nn.Module):
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
 
+        zeros = cu_seqlens.new_zeros(1)
         if num_pad > 0:
-            cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0)
+            cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
             cu_seqlens[-1] = cu_seqlens[-2] + num_pad
         else:
-            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+            cu_seqlens = torch.cat([zeros, cu_seqlens])
 
         # add batch size
         if hidden_states.ndim == 2:
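The Ernie variant is the only two-sided case: when trailing pad tokens exist, a slot is also appended and then overwritten so the pad region becomes its own attention segment. Hoisting zeros out of the branch allocates the one-element tensor once on either path. A standalone sketch with hypothetical values:

import torch

# Hypothetical: two real sequences plus 6 trailing pad tokens.
cu_seqlens = torch.tensor([100, 250], dtype=torch.int32)
num_pad = 6

zeros = cu_seqlens.new_zeros(1)
if num_pad > 0:
    # Prepend the leading 0 and append a placeholder slot...
    cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
    # ...then close the pad segment off at the true total length.
    cu_seqlens[-1] = cu_seqlens[-2] + num_pad
else:
    cu_seqlens = torch.cat([zeros, cu_seqlens])

print(cu_seqlens)  # tensor([  0, 100, 250, 256], dtype=torch.int32)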
@@ -539,7 +539,7 @@ class Qwen3_VisionTransformer(nn.Module):
             dim=0,
             dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
         )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
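For context, the unchanged lines in this hunk are the tail of building cu_seqlens from the per-image patch grids: patches per frame (h * w) repeated over t frames, then a cumulative sum. A rough, runnable reconstruction under assumed inputs (the grid_thw_tensor values are invented; the real construction lives in Qwen3_VisionTransformer):

import torch

# Hypothetical grid_thw: one (t, h, w) patch grid per image.
grid_thw_tensor = torch.tensor([[1, 8, 8], [1, 16, 8]])

cu_seqlens = torch.repeat_interleave(
    grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],  # patches per frame
    grid_thw_tensor[:, 0],                          # repeated t times
).cumsum(
    dim=0,
    # Tracing cannot branch on a runtime dtype, so the input dtype is kept
    # there; in eager mode int32 matches what varlen attention expects.
    dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
print(cu_seqlens)  # tensor([  0,  64, 192], dtype=torch.int32)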
@@ -592,7 +592,7 @@ class Siglip2Encoder(nn.Module):
             # for more information
             dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
         )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         reverse_indices = torch.argsort(window_index)
 
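The unchanged context in the Siglip2 hunk also shows the standard trick for undoing the window reordering: argsort of a permutation yields its inverse. A toy demonstration with an invented window_index:

import torch

# Hypothetical window_index: a permutation regrouping patches by window.
window_index = torch.tensor([2, 0, 3, 1])

# argsort maps each windowed position back to its original location.
reverse_indices = torch.argsort(window_index)

x = torch.arange(4)
shuffled = x[window_index]            # tensor([2, 0, 3, 1])
restored = shuffled[reverse_indices]  # tensor([0, 1, 2, 3])
assert torch.equal(restored, x)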