Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:55:01 +08:00)
[Model][QwenVL] Simplify cos/sin rotary embedding indexing (#28962)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Parent commit: 6a25ea5f0e
This commit: 3d4e7d34be
@@ -797,13 +797,8 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
return cos_combined, sin_combined, pos_ids
|
||||
|
||||
def compute_attn_mask_seqlen(
|
||||
|
||||
@@ -738,13 +738,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
|
||||
cos_combined = cos_combined.reshape(
|
||||
cos_combined.shape[0] // self.spatial_merge_unit,
|
||||
|
||||
@@ -724,13 +724,8 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
return cos_combined, sin_combined
|
||||
|
||||
def compute_attn_mask_seqlen(
|
||||
|
||||
@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
|
||||
return cos_combined, sin_combined
|
||||
|
||||
|
||||
@@ -459,18 +459,13 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
|
||||
for t, h, w in grid_thw
|
||||
]
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
|
||||
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
|
||||
return cos_combined, sin_combined
|
||||
|
||||
@@ -566,12 +561,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
|
||||
hidden_states.device, non_blocking=True
|
||||
)
|
||||
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
|
||||
hidden_states.device, non_blocking=True
|
||||
)
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user