[Model][Qwen3VL] Cache positional embedding indices (#28475)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Author: Lukas Geiger
Date: 2025-11-15 19:03:09 +00:00 (committed by GitHub)
Parent: 637f292196
Commit: 07cadab27a


@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any
@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
-        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
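
For context, the pattern the diff relies on is functools.lru_cache keyed on plain integers (h, w, spatial_merge_size), with @staticmethod on the outside so the cache is shared across transformer instances and never holds a reference to self. Below is a minimal standalone sketch of that pattern under those assumptions; the helper name grid_pos_ids and the grid sizes are illustrative, not part of the vLLM code.

    from functools import lru_cache

    import numpy as np
    import torch


    @lru_cache(maxsize=1024)
    def grid_pos_ids(h: int, w: int, merge: int) -> torch.Tensor:
        # Row/column indices of an (h, w) grid, regrouped into merge x merge
        # blocks and flattened, mirroring the layout built by rot_pos_ids above.
        hpos = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        wpos = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        shape = (h // merge, merge, w // merge, merge)
        hpos = hpos.reshape(shape).transpose(0, 2, 1, 3).flatten()
        wpos = wpos.reshape(shape).transpose(0, 2, 1, 3).flatten()
        return torch.from_numpy(np.stack([hpos, wpos], axis=-1))


    a = grid_pos_ids(32, 32, 2)
    b = grid_pos_ids(32, 32, 2)       # same (h, w, merge): served from the cache
    assert a is b                     # the identical tensor object is reused
    print(grid_pos_ids.cache_info())  # hits=1, misses=1, maxsize=1024, currsize=1

Because a cache hit returns the same tensor object every time, callers must treat it as read-only; in the diff, grids repeated over t frames get a fresh tensor via .repeat(t, 1) rather than mutating the cached one.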