From 3d4e7d34be856cc4f54033e6a019059afacb5e76 Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Wed, 19 Nov 2025 05:43:01 +0000
Subject: [PATCH] [Model][QwenVL] Simplify cos/sin rotary embedding indexing
 (#28962)

Signed-off-by: Lukas Geiger
---
 vllm/model_executor/models/glm4_1v.py             |  9 ++-------
 vllm/model_executor/models/qwen2_5_vl.py          |  9 ++-------
 vllm/model_executor/models/qwen2_vl.py            |  9 ++-------
 .../models/qwen3_omni_moe_thinker.py              |  9 ++-------
 vllm/model_executor/models/qwen3_vl.py            | 17 +++--------------
 5 files changed, 11 insertions(+), 42 deletions(-)

diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 2c2f45c2453ee..7a4fee76ae6b3 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -797,13 +797,8 @@ class Glm4vVisionTransformer(nn.Module):

         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)

         return cos_combined, sin_combined, pos_ids

     def compute_attn_mask_seqlen(
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 2e4fd9645d88f..5b5d50ec8935a 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -738,13 +738,8 @@ class Qwen2_5_VisionTransformer(nn.Module):

         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)

         cos_combined = cos_combined.reshape(
             cos_combined.shape[0] // self.spatial_merge_unit,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 53df5972a8fe1..cda8eaf5377f1 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -724,13 +724,8 @@ class Qwen2VisionTransformer(nn.Module):

         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)

         return cos_combined, sin_combined

     def compute_attn_mask_seqlen(
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 8274b92138f78..d2fd74a5e41ad 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):

         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)

         return cos_combined, sin_combined

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 99a4007ef7f23..0c546309400b7 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -459,18 +459,13 @@ class Qwen3_VisionTransformer(nn.Module):
             else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
             for t, h, w in grid_thw
         ]
-        pos_ids = torch.cat(pos_ids, dim=0)
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
-        cos_w = cos[pos_ids[:, 1]]
-        sin_h = sin[pos_ids[:, 0]]
-        sin_w = sin[pos_ids[:, 1]]
-
-        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
-        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)

         return cos_combined, sin_combined

@@ -566,12 +561,6 @@ class Qwen3_VisionTransformer(nn.Module):
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
-        rotary_pos_emb_cos = rotary_pos_emb_cos.to(
-            hidden_states.device, non_blocking=True
-        )
-        rotary_pos_emb_sin = rotary_pos_emb_sin.to(
-            hidden_states.device, non_blocking=True
-        )
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
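Reviewer note (not part of the patch): the simplification relies on advanced indexing a (max_grid_size, rotary_dim // 2) table with a (num_tokens, 2) index tensor, which yields (num_tokens, 2, rotary_dim // 2); flattening from dim 1 then reproduces the old layout of the h-half followed by the w-half. A minimal standalone sketch with made-up shapes (the real tensors come from get_cos_sin() and the pos_ids construction) could check the equivalence like this:

    import torch

    # Illustrative sizes only; real values come from the vision transformer.
    max_grid_size, half_dim, num_tokens = 32, 20, 7
    cos = torch.randn(max_grid_size, half_dim)                  # (max_grid_size, rotary_dim // 2)
    pos_ids = torch.randint(0, max_grid_size, (num_tokens, 2))  # (h, w) index per token

    # Old approach: gather the h and w rows separately, then concatenate.
    cos_h = cos[pos_ids[:, 0]]
    cos_w = cos[pos_ids[:, 1]]
    old = torch.cat([cos_h, cos_w], dim=-1)                     # (num_tokens, rotary_dim)

    # New approach: one gather to (num_tokens, 2, rotary_dim // 2), then flatten.
    new = cos[pos_ids].flatten(1)

    assert torch.equal(old, new)

The same identity applies to the sin table, so the patch changes only how the lookup is expressed, not the resulting embeddings.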