[Performance][MM] Building the inverse permutation in O(n) time in Qwen2_5_VisionTransformer (#24443)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
This commit is contained in:
WeiQing Chen 2025-09-09 15:24:11 +08:00 committed by GitHub
parent 46876dff32
commit e283976f3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -717,6 +717,15 @@ class Qwen2_5_VisionTransformer(nn.Module):
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
# building the inverse permutation in O(n) time
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel(),
device=perm.device,
dtype=perm.dtype)
return inv
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -760,6 +769,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb = torch.cat(rotary_pos_emb) rotary_pos_emb = torch.cat(rotary_pos_emb)
window_index = torch.cat(window_index) window_index = torch.cat(window_index)
# compute reverse indices
reverse_indices = self.invert_permutation(window_index)
cu_window_seqlens = torch.cat(cu_window_seqlens) cu_window_seqlens = torch.cat(cu_window_seqlens)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
cu_seqlens = torch.cat(cu_seqlens) cu_seqlens = torch.cat(cu_seqlens)
@ -813,7 +824,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
# adapter # adapter
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :] hidden_states = hidden_states[reverse_indices, :]
return hidden_states return hidden_states