[Model][Qwen3VL] Cache positional embedding indices (#28475)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>

@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any
@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
-        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
+
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
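
For context, a minimal, self-contained sketch of the memoization pattern this change relies on (assuming only functools, numpy, and torch; the helper name rot_pos_ids mirrors the diff above, but this is an illustration, not the vLLM module itself). The key point is that the position-id construction depends only on the hashable (h, w, spatial_merge_size) ints, so lru_cache can return the previously built tensor whenever the same grid size recurs.

# Illustrative sketch of the caching idea; not the vLLM source.
from functools import lru_cache

import numpy as np
import torch


@lru_cache(maxsize=1024)
def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
    # Build (h*w, 2) row/column indices rearranged into spatial-merge blocks,
    # computed once per distinct (h, w, spatial_merge_size) triple.
    h_div, w_div = h // spatial_merge_size, w // spatial_merge_size

    def block_order(ids: np.ndarray) -> np.ndarray:
        ids = ids.reshape(h_div, spatial_merge_size, w_div, spatial_merge_size)
        return ids.transpose(0, 2, 1, 3).flatten()

    hpos = block_order(np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)))
    wpos = block_order(np.broadcast_to(np.arange(w).reshape(1, w), (h, w)))
    return torch.from_numpy(np.stack([hpos, wpos], axis=-1))


# Repeated grid sizes (common when many same-resolution images or video
# frames are processed) return the identical cached tensor instead of
# recomputing it.
a = rot_pos_ids(8, 8, 2)
b = rot_pos_ids(8, 8, 2)
assert a is b
print(rot_pos_ids.cache_info())  # hits=1, misses=1, currsize=1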