[Model][Qwen3VL] Cache positional embedding indices (#28475)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 637f292196
commit 07cadab27a
@@ -25,7 +25,7 @@
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from itertools import islice
 from typing import Any
 
@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: list[list[int]]):
-        pos_ids = []
-        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
+
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
         pos_ids = torch.cat(pos_ids, dim=0)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
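Why this helps, as a minimal self-contained sketch: the (h, w) -> position-index mapping depends only on the grid height, width, and spatial merge size, so it can be memoized and reused across requests that share a resolution. The function name cached_pos_ids, the merge argument, and the toy 8x8 grid below are illustrative assumptions, not the vLLM API; the actual method is Qwen3_VisionTransformer.rot_pos_ids in the diff above.

# Minimal sketch of the caching pattern used by this commit (illustrative,
# not the vLLM code): memoize the position-id construction per grid shape.
from functools import lru_cache

import numpy as np
import torch


@lru_cache(maxsize=1024)
def cached_pos_ids(h: int, w: int, merge: int) -> torch.Tensor:
    # Row/column indices for an h x w patch grid, regrouped into
    # (merge x merge) blocks so merged patches stay adjacent after flattening.
    hpos = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
    wpos = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
    blocks = (h // merge, merge, w // merge, merge)
    hpos = hpos.reshape(blocks).transpose(0, 2, 1, 3).flatten()
    wpos = wpos.reshape(blocks).transpose(0, 2, 1, 3).flatten()
    return torch.from_numpy(np.stack([hpos, wpos], axis=-1))


# Repeated calls with the same (h, w, merge) return the cached tensor object.
a = cached_pos_ids(8, 8, 2)
b = cached_pos_ids(8, 8, 2)
assert a is b  # cache hit: no recomputation, no new allocation
print(cached_pos_ids.cache_info())

Two details worth noting from the diff itself: lru_cache requires hashable arguments, which is why the cached helper takes plain ints (h, w, spatial_merge_size) rather than the grid_thw list, and why the temporal repeat(t, 1) for multi-frame inputs happens outside the cached call. The cached tensor is also shared between callers, so it should be treated as read-only.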