mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Performance][MM] Building the inverse permutation in O(n) time in Qwen2_5_VisionTransformer (#24443)
Signed-off-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
This commit is contained in:
parent
46876dff32
commit
e283976f3a
@ -717,6 +717,15 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
return max_seqlen, seqlens
|
return max_seqlen, seqlens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
|
||||||
|
# building the inverse permutation in O(n) time
|
||||||
|
inv = torch.empty_like(perm)
|
||||||
|
inv[perm] = torch.arange(perm.numel(),
|
||||||
|
device=perm.device,
|
||||||
|
dtype=perm.dtype)
|
||||||
|
return inv
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -760,6 +769,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
rotary_pos_emb = torch.cat(rotary_pos_emb)
|
rotary_pos_emb = torch.cat(rotary_pos_emb)
|
||||||
window_index = torch.cat(window_index)
|
window_index = torch.cat(window_index)
|
||||||
|
# compute reverse indices
|
||||||
|
reverse_indices = self.invert_permutation(window_index)
|
||||||
cu_window_seqlens = torch.cat(cu_window_seqlens)
|
cu_window_seqlens = torch.cat(cu_window_seqlens)
|
||||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||||
cu_seqlens = torch.cat(cu_seqlens)
|
cu_seqlens = torch.cat(cu_seqlens)
|
||||||
@ -813,7 +824,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
# adapter
|
# adapter
|
||||||
hidden_states = self.merger(hidden_states)
|
hidden_states = self.merger(hidden_states)
|
||||||
reverse_indices = torch.argsort(window_index)
|
|
||||||
hidden_states = hidden_states[reverse_indices, :]
|
hidden_states = hidden_states[reverse_indices, :]
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user