[Model][Perf] Use cos and sin cache in QwenVL (#28798)

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Canlin Guo 2025-11-18 19:51:54 +08:00 committed by GitHub
parent 285eaa4285
commit b9489f51e1
6 changed files with 218 additions and 217 deletions

View File

@@ -83,6 +83,11 @@ class RotaryEmbeddingBase(CustomOp):
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[:seqlen]
cos, sin = cos_sin.chunk(2, dim=-1)
return cos, sin
class RotaryEmbedding(RotaryEmbeddingBase):
def __init__(
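For context, a minimal self-contained sketch (illustrative sizes, not the exact vLLM implementation) of the cache layout the new get_cos_sin helper assumes: each row of cos_sin_cache holds the cos half and the sin half concatenated, so one slice plus a chunk recovers both tables without recomputing any trigonometry per forward pass.

import torch

def build_cos_sin_cache(rotary_dim: int = 32, max_position: int = 8192, base: float = 10000.0) -> torch.Tensor:
    # Construction that get_rope(..., max_position=8192, base=10000.0) is assumed to perform.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    positions = torch.arange(max_position, dtype=torch.float)
    freqs = torch.outer(positions, inv_freq)               # (max_position, rotary_dim // 2)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)   # (max_position, rotary_dim)

cache = build_cos_sin_cache()
cos, sin = cache[:256].chunk(2, dim=-1)                    # what get_cos_sin(256) would return
assert cos.shape == sin.shape == (256, 16)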

View File

@@ -65,6 +65,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -341,7 +342,8 @@ class Glm4vVisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
@@ -353,10 +355,12 @@ class Glm4vVisionAttention(nn.Module):
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
@@ -454,14 +458,16 @@ class Glm4vVisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -660,44 +666,6 @@ class Glm4vVisionEmbeddings(nn.Module):
return embeddings
class Glm4vVisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0,
self.dim,
2,
dtype=torch.float,
device=self.inv_freq.device,
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Glm4vVisionTransformer(nn.Module):
def __init__(
self,
@@ -731,7 +699,13 @@ class Glm4vVisionTransformer(nn.Module):
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
[
Glm4vVisionBlock(
@@ -789,7 +763,9 @@ class Glm4vVisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
def rot_pos_emb(
self, grid_thw: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -817,9 +793,18 @@ class Glm4vVisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined, pos_ids
def compute_attn_mask_seqlen(
self,
@@ -848,7 +833,9 @@ class Glm4vVisionTransformer(nn.Module):
x = self.post_conv_layernorm(x)
# compute position embedding
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -867,7 +854,8 @@ class Glm4vVisionTransformer(nn.Module):
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
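As a sanity check on the refactor in this file, here is a self-contained toy sketch (hypothetical sizes, not the model code) of why gathering from precomputed cos/sin tables matches the removed freqs-then-cos path: indexing the frequency table by (h, w) and taking cos afterwards gives the same values as indexing cos/sin tables that were taken once up front.

import torch

dim, max_size = 16, 8                        # dim plays the role of rotary_dim here
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(max_size, dtype=torch.float), inv_freq)  # (max_size, dim // 2)
pos_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])                  # (h, w) index per patch

# Removed path: gather frequencies per position, take cos/sin later in attention.
old_cos = freqs[pos_ids].flatten(1).cos()
old_sin = freqs[pos_ids].flatten(1).sin()

# New path: take cos/sin once, then gather and concatenate the h and w halves.
cos, sin = freqs.cos(), freqs.sin()
new_cos = torch.cat([cos[pos_ids[:, 0]], cos[pos_ids[:, 1]]], dim=-1)
new_sin = torch.cat([sin[pos_ids[:, 0]], sin[pos_ids[:, 1]]], dim=-1)
assert torch.allclose(old_cos, new_cos) and torch.allclose(old_sin, new_sin)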

View File

@@ -64,6 +64,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@@ -363,7 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
@@ -378,13 +380,15 @@ class Qwen2_5_VisionAttention(nn.Module):
head=self.num_attention_heads_per_partition,
)
if rotary_pos_emb is not None:
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
qk, v = qkv[:, :, :2], qkv[:, :, 2]
qk_reshaped = einops.rearrange(
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
)
qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
qk_rotated = apply_rotary_pos_emb_vision(
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
)
qk_rotated = qk_rotated.view(
2,
batch_size,
@@ -434,7 +438,8 @@ class Qwen2_5_VisionAttention(nn.Module):
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
"rotary_pos_emb": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
@@ -485,14 +490,16 @@ class Qwen2_5_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -588,42 +595,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
return out
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (
theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2_5_VisionTransformer(nn.Module):
def __init__(
self,
@@ -666,7 +637,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
@@ -757,15 +734,30 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
max_size = max(h, w)
rotary_pos_emb_full = self.rotary_pos_emb(max_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.reshape(
rotary_pos_emb.shape[0] // self.spatial_merge_unit,
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos_combined.reshape(
cos_combined.shape[0] // self.spatial_merge_unit,
self.spatial_merge_unit,
-1,
)
sin_combined = sin_combined.reshape(
sin_combined.shape[0] // self.spatial_merge_unit,
self.spatial_merge_unit,
-1,
)
return rotary_pos_emb
return cos_combined, sin_combined
def get_window_index_thw(self, grid_t, grid_h, grid_w):
vit_merger_window_size = (
@@ -807,14 +799,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
@lru_cache(maxsize=1024) # noqa: B019
def get_rope_by_thw(self, t, h, w):
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w)
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w)
cos_thw = cos_thw[window_index_thw, :, :]
cos_thw = cos_thw.flatten(start_dim=0, end_dim=1)
sin_thw = sin_thw[window_index_thw, :, :]
sin_thw = sin_thw.flatten(start_dim=0, end_dim=1)
cu_seqlens_thw = torch.repeat_interleave(
torch.tensor([h * w], dtype=torch.int32), t
)
return (
rotary_pos_emb_thw,
cos_thw,
sin_thw,
window_index_thw,
cu_seqlens_window_thw,
cu_seqlens_thw,
@@ -849,7 +846,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> torch.Tensor:
# patchify
seq_len, _ = x.size()
rotary_pos_emb = []
rotary_pos_emb_cos = []
rotary_pos_emb_sin = []
window_index: list = []
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
cu_seqlens: list = []
@@ -865,7 +863,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
llm_w = w // self.spatial_merge_size
(
rotary_pos_emb_thw,
cos_thw,
sin_thw,
window_index_thw,
cu_seqlens_window_thw,
cu_seqlens_thw,
@@ -878,11 +877,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens_last = cu_seqlens_window_thw[-1]
cu_window_seqlens.append(cu_seqlens_window_thw)
rotary_pos_emb.append(rotary_pos_emb_thw)
rotary_pos_emb_cos.append(cos_thw)
rotary_pos_emb_sin.append(sin_thw)
cu_seqlens.append(cu_seqlens_thw)
rotary_pos_emb = torch.cat(rotary_pos_emb)
rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos)
rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin)
window_index = torch.cat(window_index)
# compute reverse indices
reverse_indices = self.invert_permutation(window_index)
@@ -901,7 +902,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
device=self.device, non_blocking=True
)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
device=self.device, non_blocking=True
)
window_index = window_index.to(device=hidden_states.device, non_blocking=True)
reverse_indices = reverse_indices.to(
device=hidden_states.device, non_blocking=True
@@ -928,7 +934,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
)
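The Qwen2.5-VL path additionally reshapes the per-(t, h, w) cos/sin tables by spatial_merge_unit and permutes them with the window index before flattening; a toy sketch of that reorder (made-up sizes and a made-up permutation, purely illustrative):

import torch

num_tokens, spatial_merge_unit, dim = 8, 4, 6
cos_thw = torch.randn(num_tokens, dim)        # stands in for cos_combined above

# rot_pos_emb_thw-style grouping into spatial-merge units
cos_thw = cos_thw.reshape(num_tokens // spatial_merge_unit, spatial_merge_unit, dim)

# get_rope_by_thw-style window permutation, then flatten back to per-token rows
window_index = torch.tensor([1, 0])           # hypothetical output of get_window_index_thw
cos_thw = cos_thw[window_index, :, :].flatten(start_dim=0, end_dim=1)
assert cos_thw.shape == (num_tokens, dim)     # sin_thw is reordered identically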

View File

@@ -32,7 +32,7 @@ from typing import Annotated, Any, Literal, TypeAlias
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -59,7 +59,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
dispatch_rotary_emb_function,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -275,47 +277,13 @@ class Qwen2VisionMLP(nn.Module):
return x
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
def apply_rotary_pos_emb_vision(
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
rotary_emb_function = dispatch_rotary_emb_function(
default=partial(apply_rotary_emb_torch, is_neox_style=True)
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = rotary_emb_function(t_, cos, sin).type_as(t)
output = rotary_emb_function(t, cos, sin).type_as(t)
return output
@@ -412,7 +380,8 @@ class Qwen2VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
@@ -424,11 +393,13 @@ class Qwen2VisionAttention(nn.Module):
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
if rotary_pos_emb is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
@@ -534,14 +505,16 @@ class Qwen2VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -628,40 +601,6 @@ class Qwen2VisionPatchMerger(nn.Module):
return out
class Qwen2VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
def __init__(
self,
@@ -700,7 +639,13 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
[
@@ -744,7 +689,9 @@ class Qwen2VisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
def rot_pos_emb(
self, grid_thw: list[list[int]]
) -> tuple[torch.Tensor, torch.Tensor]:
pos_ids = []
max_grid_size = 0
for t, h, w in grid_thw:
@@ -773,9 +720,18 @@ class Qwen2VisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
max_grid_size = max(max_grid_size, h, w)
pos_ids = torch.cat(pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
@@ -806,7 +762,7 @@ class Qwen2VisionTransformer(nn.Module):
grid_thw_list = grid_thw.tolist()
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
@@ -824,7 +780,8 @@ class Qwen2VisionTransformer(nn.Module):
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
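The reworked apply_rotary_pos_emb_vision now receives cos and sin directly and dispatches to a rotary kernel, falling back to a pure-torch neox-style application. Roughly what that fallback computes, adapted as a sketch from the rotate_half/apply_rotary_emb_torch helper this diff removes (toy shapes; the shared vLLM helper may differ in details):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, heads, head_dim); cos/sin: (seqlen, rotary_dim // 2)
    ro_dim = cos.shape[-1] * 2
    cos = torch.cat([cos, cos], dim=-1)[:, None, :]   # (seqlen, 1, rotary_dim), broadcast over heads
    sin = torch.cat([sin, sin], dim=-1)[:, None, :]
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    return torch.cat([x_rot * cos + rotate_half(x_rot) * sin, x_pass], dim=-1)

q = torch.randn(2, 5, 3, 8)                           # toy (batch, seq, heads, head_dim)
cos, sin = torch.rand(5, 4), torch.rand(5, 4)         # toy tables of width rotary_dim // 2
assert apply_rotary_neox(q, cos, sin).shape == q.shape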

View File

@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo
@@ -90,7 +91,6 @@ from .qwen2_5_omni_thinker import (
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
Qwen2_5_VisionRotaryEmbedding,
Qwen2_5_VLProcessingInfo,
)
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
@@ -221,14 +221,16 @@ class Qwen3_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -332,7 +334,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
[
@@ -416,9 +424,19 @@ class Qwen3Omni_VisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
@@ -508,7 +526,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if self.apply_vit_abs_pos_embed:
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -519,7 +537,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
@@ -529,7 +548,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)

View File

@@ -63,6 +63,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -95,7 +96,6 @@ from .interfaces import (
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
Qwen2_5_VisionRotaryEmbedding,
Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
@@ -232,14 +232,16 @@ class Qwen3_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -339,7 +341,13 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
@@ -452,9 +460,19 @@ class Qwen3_VisionTransformer(nn.Module):
for t, h, w in grid_thw
]
pos_ids = torch.cat(pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
@@ -547,8 +565,13 @@ class Qwen3_VisionTransformer(nn.Module):
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
hidden_states.device, non_blocking=True
)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
hidden_states.device, non_blocking=True
)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -564,7 +587,8 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)