diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fa6b71bf92682..7f0c9372991d1 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -25,7 +25,7 @@ """Inference-only Qwen3VL model compatible with HuggingFace weights.""" from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence -from functools import partial +from functools import lru_cache, partial from itertools import islice from typing import Any @@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module): def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: list[list[int]]): - pos_ids = [] - max_grid_size = max(max(h, w) for _, h, w in grid_thw) - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: list[list[int]]): + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + pos_ids = [ + self.rot_pos_ids(h, w, self.spatial_merge_size) + if t == 1 + else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) + for t, h, w in grid_thw + ] pos_ids = torch.cat(pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)