mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 06:17:05 +08:00
[MM][Perf] Minor Optimization on Qwen3-VL fast_pos_embed_interpolate (#25337)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
cf56cf78b4
commit
30d08911f7
@ -270,6 +270,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
self.temporal_patch_size = vision_config.temporal_patch_size
|
self.temporal_patch_size = vision_config.temporal_patch_size
|
||||||
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
|
||||||
self.use_data_parallel = use_data_parallel
|
self.use_data_parallel = use_data_parallel
|
||||||
|
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
|
||||||
|
|
||||||
# NOTE: This is used for creating empty tensor for all_gather for
|
# NOTE: This is used for creating empty tensor for all_gather for
|
||||||
# DP ViT. Here out_hidden_size is enlarged due to deepstack
|
# DP ViT. Here out_hidden_size is enlarged due to deepstack
|
||||||
@ -377,82 +378,68 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
return rotary_pos_emb
|
return rotary_pos_emb
|
||||||
|
|
||||||
def fast_pos_embed_interpolate(self, grid_thw):
|
def fast_pos_embed_interpolate(self,
|
||||||
num_grid_per_side = int(self.num_position_embeddings**0.5)
|
grid_thw: list[list[int]]) -> torch.Tensor:
|
||||||
|
|
||||||
idx_list = [[] for _ in range(4)]
|
num_grid_per_side = self.num_grid_per_side
|
||||||
weight_list = [[] for _ in range(4)]
|
m_size = self.spatial_merge_size
|
||||||
|
hidden_dim = self.pos_embed.embedding_dim
|
||||||
|
|
||||||
|
outputs = []
|
||||||
for t, h, w in grid_thw:
|
for t, h, w in grid_thw:
|
||||||
h_idxs = torch.linspace(0,
|
h_idxs = torch.linspace(0,
|
||||||
num_grid_per_side - 1,
|
num_grid_per_side - 1,
|
||||||
h,
|
h,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32,
|
||||||
|
device=self.device)
|
||||||
w_idxs = torch.linspace(0,
|
w_idxs = torch.linspace(0,
|
||||||
num_grid_per_side - 1,
|
num_grid_per_side - 1,
|
||||||
w,
|
w,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
h_idxs_floor = h_idxs.to(torch.long)
|
h_floor = h_idxs.to(torch.long)
|
||||||
w_idxs_floor = w_idxs.to(torch.long)
|
w_floor = w_idxs.to(torch.long)
|
||||||
h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
|
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
|
||||||
max=num_grid_per_side - 1)
|
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
|
||||||
w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
|
|
||||||
max=num_grid_per_side - 1)
|
|
||||||
|
|
||||||
dh = h_idxs - h_idxs_floor
|
dh = h_idxs - h_floor
|
||||||
dw = w_idxs - w_idxs_floor
|
dw = w_idxs - w_floor
|
||||||
|
|
||||||
idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
|
w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
|
||||||
w_idxs_floor[None]).flatten().tolist() * t)
|
w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
|
||||||
idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
|
w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
|
||||||
w_idxs_ceil[None]).flatten().tolist() * t)
|
w11 = (dh[:, None] * dw[None, :]).reshape(-1)
|
||||||
idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
|
|
||||||
w_idxs_floor[None]).flatten().tolist() * t)
|
|
||||||
idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
|
|
||||||
w_idxs_ceil[None]).flatten().tolist() * t)
|
|
||||||
|
|
||||||
weight_list[0].extend(
|
idx00 = (h_floor[:, None] * num_grid_per_side +
|
||||||
((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
|
w_floor[None, :]).reshape(-1)
|
||||||
weight_list[1].extend(
|
idx01 = (h_floor[:, None] * num_grid_per_side +
|
||||||
((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
|
w_ceil[None, :]).reshape(-1)
|
||||||
weight_list[2].extend(
|
idx10 = (h_ceil[:, None] * num_grid_per_side +
|
||||||
(dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
|
w_floor[None, :]).reshape(-1)
|
||||||
weight_list[3].extend(
|
idx11 = (h_ceil[:, None] * num_grid_per_side +
|
||||||
(dh[None].T * dw[None]).flatten().tolist() * t)
|
w_ceil[None, :]).reshape(-1)
|
||||||
|
|
||||||
device = self.pos_embed.weight.device
|
indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
|
||||||
dtype = self.pos_embed.weight.dtype
|
weights = torch.stack([w00, w01, w10, w11],
|
||||||
|
dim=0).to(dtype=self.dtype,
|
||||||
|
device=self.device)
|
||||||
|
weights = weights.unsqueeze(-1)
|
||||||
|
|
||||||
p0 = self.pos_embed(
|
embeds = self.pos_embed(indices)
|
||||||
torch.tensor(
|
weighted_embeds = embeds * weights
|
||||||
idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
|
p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
|
||||||
weight_list[0], dtype=dtype, device=device)[:, None]
|
combined = p0 + p1 + p2 + p3
|
||||||
p1 = self.pos_embed(
|
|
||||||
torch.tensor(
|
|
||||||
idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
|
|
||||||
weight_list[1], dtype=dtype, device=device)[:, None]
|
|
||||||
p2 = self.pos_embed(
|
|
||||||
torch.tensor(
|
|
||||||
idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
|
|
||||||
weight_list[2], dtype=dtype, device=device)[:, None]
|
|
||||||
p3 = self.pos_embed(
|
|
||||||
torch.tensor(
|
|
||||||
idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
|
|
||||||
weight_list[3], dtype=dtype, device=device)[:, None]
|
|
||||||
|
|
||||||
patch_pos_embeds = p0 + p1 + p2 + p3
|
combined = combined.view(h * w, hidden_dim)
|
||||||
patch_pos_embeds = patch_pos_embeds.split(
|
repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
|
||||||
[t * h * w for t, h, w in grid_thw])
|
repeated = repeated.view(t, h // m_size, m_size, w // m_size,
|
||||||
patch_pos_embeds_permute = []
|
m_size, hidden_dim)
|
||||||
m_size = self.spatial_merge_size
|
repeated = repeated.permute(0, 1, 3, 2, 4,
|
||||||
for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
|
5).reshape(-1, hidden_dim)
|
||||||
pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
|
outputs.append(repeated)
|
||||||
m_size, -1).permute(0, 1, 3, 2, 4,
|
|
||||||
5).flatten(0, 4)
|
return torch.cat(outputs, dim=0)
|
||||||
patch_pos_embeds_permute.append(pos_embed)
|
|
||||||
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
|
|
||||||
return patch_pos_embeds
|
|
||||||
|
|
||||||
def compute_attn_mask_seqlen(
|
def compute_attn_mask_seqlen(
|
||||||
self,
|
self,
|
||||||
@ -477,12 +464,9 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
hidden_states = hidden_states + pos_embeds
|
hidden_states = hidden_states + pos_embeds
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||||
|
|
||||||
if isinstance(grid_thw, list):
|
grid_thw_tensor = torch.tensor(grid_thw,
|
||||||
grid_thw_tensor = torch.tensor(grid_thw,
|
device=self.device,
|
||||||
device=hidden_states.device,
|
dtype=torch.int32)
|
||||||
dtype=torch.int32)
|
|
||||||
else:
|
|
||||||
grid_thw_tensor = grid_thw
|
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
||||||
@ -1224,7 +1208,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
grid_thw_list,
|
grid_thw_list,
|
||||||
rope_type="rope_3d")
|
rope_type="rope_3d")
|
||||||
else:
|
else:
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
image_embeds = self.visual(pixel_values,
|
||||||
|
grid_thw=grid_thw_list)
|
||||||
|
|
||||||
# Split concatenated embeddings for each image item.
|
# Split concatenated embeddings for each image item.
|
||||||
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
||||||
@ -1526,4 +1511,4 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
language_model="language_model",
|
language_model="language_model",
|
||||||
connector="model.visual.merger",
|
connector="model.visual.merger",
|
||||||
tower_model="model.visual.",
|
tower_model="model.visual.",
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user