[MM][Perf] Minor Optimization on Qwen3-VL fast_pos_embed_interpolate (#25337)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang 2025-09-21 04:05:20 -07:00 committed by GitHub
parent cf56cf78b4
commit 30d08911f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -270,6 +270,7 @@ class Qwen3_VisionTransformer(nn.Module):
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
@ -377,82 +378,68 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def fast_pos_embed_interpolate(self, grid_thw):
num_grid_per_side = int(self.num_position_embeddings**0.5)
def fast_pos_embed_interpolate(self,
grid_thw: list[list[int]]) -> torch.Tensor:
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(0,
num_grid_per_side - 1,
h,
dtype=torch.float32)
dtype=torch.float32,
device=self.device)
w_idxs = torch.linspace(0,
num_grid_per_side - 1,
w,
dtype=torch.float32)
dtype=torch.float32,
device=self.device)
h_idxs_floor = h_idxs.to(torch.long)
w_idxs_floor = w_idxs.to(torch.long)
h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
max=num_grid_per_side - 1)
w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
max=num_grid_per_side - 1)
h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
dh = h_idxs - h_floor
dw = w_idxs - w_floor
idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
w_idxs_floor[None]).flatten().tolist() * t)
idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
w_idxs_ceil[None]).flatten().tolist() * t)
idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
w_idxs_floor[None]).flatten().tolist() * t)
idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
w_idxs_ceil[None]).flatten().tolist() * t)
w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
w11 = (dh[:, None] * dw[None, :]).reshape(-1)
weight_list[0].extend(
((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
weight_list[1].extend(
((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
weight_list[2].extend(
(dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
weight_list[3].extend(
(dh[None].T * dw[None]).flatten().tolist() * t)
idx00 = (h_floor[:, None] * num_grid_per_side +
w_floor[None, :]).reshape(-1)
idx01 = (h_floor[:, None] * num_grid_per_side +
w_ceil[None, :]).reshape(-1)
idx10 = (h_ceil[:, None] * num_grid_per_side +
w_floor[None, :]).reshape(-1)
idx11 = (h_ceil[:, None] * num_grid_per_side +
w_ceil[None, :]).reshape(-1)
device = self.pos_embed.weight.device
dtype = self.pos_embed.weight.dtype
indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
weights = torch.stack([w00, w01, w10, w11],
dim=0).to(dtype=self.dtype,
device=self.device)
weights = weights.unsqueeze(-1)
p0 = self.pos_embed(
torch.tensor(
idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
weight_list[0], dtype=dtype, device=device)[:, None]
p1 = self.pos_embed(
torch.tensor(
idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
weight_list[1], dtype=dtype, device=device)[:, None]
p2 = self.pos_embed(
torch.tensor(
idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
weight_list[2], dtype=dtype, device=device)[:, None]
p3 = self.pos_embed(
torch.tensor(
idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
weight_list[3], dtype=dtype, device=device)[:, None]
embeds = self.pos_embed(indices)
weighted_embeds = embeds * weights
p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
combined = p0 + p1 + p2 + p3
patch_pos_embeds = p0 + p1 + p2 + p3
patch_pos_embeds = patch_pos_embeds.split(
[t * h * w for t, h, w in grid_thw])
patch_pos_embeds_permute = []
m_size = self.spatial_merge_size
for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
m_size, -1).permute(0, 1, 3, 2, 4,
5).flatten(0, 4)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
combined = combined.view(h * w, hidden_dim)
repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
repeated = repeated.view(t, h // m_size, m_size, w // m_size,
m_size, hidden_dim)
repeated = repeated.permute(0, 1, 3, 2, 4,
5).reshape(-1, hidden_dim)
outputs.append(repeated)
return torch.cat(outputs, dim=0)
def compute_attn_mask_seqlen(
self,
@ -477,12 +464,9 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
if isinstance(grid_thw, list):
grid_thw_tensor = torch.tensor(grid_thw,
device=hidden_states.device,
dtype=torch.int32)
else:
grid_thw_tensor = grid_thw
grid_thw_tensor = torch.tensor(grid_thw,
device=self.device,
dtype=torch.int32)
cu_seqlens = torch.repeat_interleave(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
@ -1224,7 +1208,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
grid_thw_list,
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@ -1526,4 +1511,4 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model",
connector="model.visual.merger",
tower_model="model.visual.",
)
)