mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 04:24:56 +08:00
[Perf] Further optimization for Qwen3-VL fast_pos_embed_interpolate (#25347)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
1c3ffdbecc
commit
af7dfb0d1a
@ -405,25 +405,39 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
dh = h_idxs - h_floor
|
dh = h_idxs - h_floor
|
||||||
dw = w_idxs - w_floor
|
dw = w_idxs - w_floor
|
||||||
|
|
||||||
w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
|
# Create meshgrid view for all h, w vars
|
||||||
w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
|
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
|
||||||
w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
|
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
|
||||||
w11 = (dh[:, None] * dw[None, :]).reshape(-1)
|
w_floor,
|
||||||
|
indexing='ij')
|
||||||
|
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
|
||||||
|
w_ceil,
|
||||||
|
indexing='ij')
|
||||||
|
h_floor_grid_idx = h_floor_grid * num_grid_per_side
|
||||||
|
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
|
||||||
|
|
||||||
idx00 = (h_floor[:, None] * num_grid_per_side +
|
# original computation of weights
|
||||||
w_floor[None, :]).reshape(-1)
|
# w00 = (1 - dh_grid) * (1 - dw_grid)
|
||||||
idx01 = (h_floor[:, None] * num_grid_per_side +
|
# w01 = (1 - dh_grid) * dw_grid
|
||||||
w_ceil[None, :]).reshape(-1)
|
# w10 = dh_grid * (1 - dw_grid)
|
||||||
idx10 = (h_ceil[:, None] * num_grid_per_side +
|
# w11 = dh_grid * dw_grid
|
||||||
w_floor[None, :]).reshape(-1)
|
# we reuse w11 here to avoid duplicate
|
||||||
idx11 = (h_ceil[:, None] * num_grid_per_side +
|
# dh_grid * dw_grid computation
|
||||||
w_ceil[None, :]).reshape(-1)
|
w11 = dh_grid * dw_grid
|
||||||
|
w10 = dh_grid - w11
|
||||||
|
w01 = dw_grid - w11
|
||||||
|
w00 = 1 - dh_grid - dw_grid + w11
|
||||||
|
|
||||||
indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
|
idx00 = h_floor_grid_idx + w_floor_grid
|
||||||
|
idx01 = h_floor_grid_idx + w_ceil_grid
|
||||||
|
idx10 = h_ceil_grid_idx + w_floor_grid
|
||||||
|
idx11 = h_ceil_grid_idx + w_ceil_grid
|
||||||
|
|
||||||
|
indices = torch.stack([idx00, idx01, idx10, idx11],
|
||||||
|
dim=0).reshape(4, -1)
|
||||||
weights = torch.stack([w00, w01, w10, w11],
|
weights = torch.stack([w00, w01, w10, w11],
|
||||||
dim=0).to(dtype=self.dtype,
|
dim=0).reshape(4, -1, 1)
|
||||||
device=self.device)
|
weights = weights.to(dtype=self.dtype, device=self.device)
|
||||||
weights = weights.unsqueeze(-1)
|
|
||||||
|
|
||||||
embeds = self.pos_embed(indices)
|
embeds = self.pos_embed(indices)
|
||||||
weighted_embeds = embeds * weights
|
weighted_embeds = embeds * weights
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user