[Model][QwenVL] Simplify cos/sin rotary embedding indexing (#28962)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Lukas Geiger 2025-11-19 05:43:01 +00:00 committed by GitHub
parent 6a25ea5f0e
commit 3d4e7d34be
5 changed files with 11 additions and 42 deletions

@@ -797,13 +797,8 @@ class Glm4vVisionTransformer(nn.Module):
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
         return cos_combined, sin_combined, pos_ids

     def compute_attn_mask_seqlen(
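Note on the simplification (the same pattern repeats in every file below): pos_ids has shape (num_tokens, 2), so cos[pos_ids] gathers a (num_tokens, 2, rotary_dim // 2) tensor and flatten(1) lays the h and w halves side by side, exactly like the old per-column indexing plus torch.cat. A minimal sketch with made-up shapes (not the models' real dimensions) to check the equivalence:

import torch

# Illustrative sizes only.
num_tokens, max_grid_size, half_dim = 16, 32, 20
cos = torch.randn(max_grid_size, half_dim)                  # cached rows
pos_ids = torch.randint(0, max_grid_size, (num_tokens, 2))  # (h, w) indices

# Old path: gather each column separately, then concatenate.
old = torch.cat([cos[pos_ids[:, 0]], cos[pos_ids[:, 1]]], dim=-1)
# New path: one advanced-indexing gather, then flatten the trailing dims.
new = cos[pos_ids].flatten(1)

assert torch.equal(old, new)  # both are (num_tokens, 2 * half_dim)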

@@ -738,13 +738,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
         cos_combined = cos_combined.reshape(
             cos_combined.shape[0] // self.spatial_merge_unit,

@@ -724,13 +724,8 @@ class Qwen2VisionTransformer(nn.Module):
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
         return cos_combined, sin_combined

     def compute_attn_mask_seqlen(

@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
         return cos_combined, sin_combined

@@ -459,18 +459,13 @@ class Qwen3_VisionTransformer(nn.Module):
             else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
             for t, h, w in grid_thw
         ]
-        pos_ids = torch.cat(pos_ids, dim=0)
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
         return cos_combined, sin_combined
@@ -566,12 +561,6 @@ class Qwen3_VisionTransformer(nn.Module):
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
-        rotary_pos_emb_cos = rotary_pos_emb_cos.to(
-            hidden_states.device, non_blocking=True
-        )
-        rotary_pos_emb_sin = rotary_pos_emb_sin.to(
-            hidden_states.device, non_blocking=True
-        )
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
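In the Qwen3 path the index tensor is now moved to the device once (pos_ids ... .to(self.device, non_blocking=True)) before indexing the cached cos/sin, so the gathered embeddings already live where hidden_states does and the later per-output .to(hidden_states.device, non_blocking=True) calls become unnecessary. A rough sketch of the idea, with hypothetical names (cache, pos_ids) and a CUDA device assumed when available:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
cache = torch.randn(64, 20, device=device)  # cos/sin cache already on device
pos_ids = torch.randint(0, 64, (16, 2))     # built on CPU from grid_thw

# Move the small index tensor once; with pinned memory, non_blocking lets the
# copy overlap other host work. The gather then produces its output on the
# same device as the cache, so no follow-up .to(...) on the results is needed.
pos_ids = pos_ids.to(device, non_blocking=True)
out = cache[pos_ids].flatten(1)
assert out.device == cache.device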