mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Performance][MM] Building the inverse permutation in O(n) time in Qwen2_5_VisionTransformer (#24443)
Signed-off-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
This commit is contained in:
parent
46876dff32
commit
e283976f3a
@ -717,6 +717,15 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
return max_seqlen, seqlens
|
||||
|
||||
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
    """Return the inverse of the index permutation ``perm``.

    The result ``out`` satisfies ``out[perm[i]] == i`` for every position
    ``i``. It is filled with one indexed assignment instead of a sort.

    Args:
        perm: Index tensor to invert. Assumed to be a 1-D integer tensor
            containing each of ``0 .. perm.numel() - 1`` exactly once;
            this is not validated here.

    Returns:
        A tensor with the same shape, dtype and device as ``perm``.
    """
    # Source positions 0..n-1, created on perm's device with perm's dtype.
    positions = torch.arange(perm.numel(),
                             dtype=perm.dtype,
                             device=perm.device)
    # Every slot is overwritten below as long as perm is a true permutation,
    # so uninitialized storage is sufficient.
    inverse = torch.empty_like(perm)
    inverse[perm] = positions
    return inverse
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -760,6 +769,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
|
||||
rotary_pos_emb = torch.cat(rotary_pos_emb)
|
||||
window_index = torch.cat(window_index)
|
||||
# compute reverse indices
|
||||
reverse_indices = self.invert_permutation(window_index)
|
||||
cu_window_seqlens = torch.cat(cu_window_seqlens)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
cu_seqlens = torch.cat(cu_seqlens)
|
||||
@ -813,7 +824,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
|
||||
# adapter
|
||||
hidden_states = self.merger(hidden_states)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[reverse_indices, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user