[Model][Qwen3VL] Cache positional embedding indices (#28475)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 637f292196
commit 07cadab27a
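This change replaces the per-call, torch-based index construction inside `rot_pos_emb` with a static helper, `rot_pos_ids`, that builds the (h, w) position-index pairs in NumPy and memoizes the resulting tensor with `functools.lru_cache`, keyed on the plain integers `(h, w, spatial_merge_size)`. A minimal standalone sketch of the same pattern is shown below before the diff; the names `cached_rot_pos_ids` and `merge` are illustrative only and do not appear in the commit.

```python
from functools import lru_cache

import numpy as np
import torch


@lru_cache(maxsize=1024)
def cached_rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
    # Row/column index grids of shape (h, w); broadcast_to avoids materialising
    # them until the reshape below copies into a contiguous array.
    hpos = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
    wpos = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))

    def merge(ids: np.ndarray) -> np.ndarray:
        # Reorder indices into spatial_merge_size x spatial_merge_size blocks so
        # patches that get merged together end up adjacent after flattening.
        ids = ids.reshape(
            h // spatial_merge_size,
            spatial_merge_size,
            w // spatial_merge_size,
            spatial_merge_size,
        )
        return ids.transpose(0, 2, 1, 3).reshape(-1)

    # Arguments are plain ints, so lru_cache can hash them; repeated requests
    # with the same grid size reuse the (h * w, 2) tensor built here.
    return torch.from_numpy(np.stack([merge(hpos), merge(wpos)], axis=-1))
```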
@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any
 
@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
-        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
 
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
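Continuing the sketch above, a quick hypothetical check of the cache behaviour: the second request with the same spatial grid is a cache hit, and grids with t > 1 only pay for a cheap `repeat` of the cached tensor, which allocates a new tensor and leaves the cached entry untouched; this mirrors how the new `rot_pos_emb` splits the `t == 1` and `t > 1` cases. The grid values here are made up for illustration.

```python
# Hypothetical usage of cached_rot_pos_ids from the sketch above.
grid_thw = [[1, 32, 32], [4, 32, 32], [1, 64, 48]]

for t, h, w in grid_thw:
    ids = cached_rot_pos_ids(h, w, 2)  # spatial_merge_size=2 assumed here
    if t > 1:
        ids = ids.repeat(t, 1)  # Tensor.repeat returns a new tensor
    print(t, tuple(ids.shape))  # shapes: (1024, 2), (4096, 2), (3072, 2)

info = cached_rot_pos_ids.cache_info()
print(info.hits, info.misses)  # 1 hit (second 32x32 grid), 2 misses
```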