diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index aa28c07ddceb..98d65dea2739 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -405,25 +405,39 @@ class Qwen3_VisionTransformer(nn.Module): dh = h_idxs - h_floor dw = w_idxs - w_floor - w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1) - w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1) - w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1) - w11 = (dh[:, None] * dw[None, :]).reshape(-1) + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, + w_floor, + indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, + w_ceil, + indexing='ij') + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side - idx00 = (h_floor[:, None] * num_grid_per_side + - w_floor[None, :]).reshape(-1) - idx01 = (h_floor[:, None] * num_grid_per_side + - w_ceil[None, :]).reshape(-1) - idx10 = (h_ceil[:, None] * num_grid_per_side + - w_floor[None, :]).reshape(-1) - idx11 = (h_ceil[:, None] * num_grid_per_side + - w_ceil[None, :]).reshape(-1) + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 - indices = torch.stack([idx00, idx01, idx10, idx11], dim=0) + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = torch.stack([idx00, idx01, idx10, idx11], + dim=0).reshape(4, -1) weights = torch.stack([w00, w01, w10, w11], - dim=0).to(dtype=self.dtype, - device=self.device) - weights = weights.unsqueeze(-1) + dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype, device=self.device) embeds = self.pos_embed(indices) weighted_embeds = embeds * weights