[Model][QwenVL] Replace torch.repeat_interleave with faster np.repeat (#28964)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-11-20 06:04:23 +00:00 committed by GitHub
parent 64192d5624
commit a9705a290a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 23 deletions

View File

@@ -128,12 +128,7 @@ def batch_make_image_embeddings(
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
image_grid_thw_on_device = image_grid_thw.to(
visual.device, dtype=torch.int64
)
return visual(
pixel_values_on_device, grid_thw=image_grid_thw_on_device
).cpu()
return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
@@ -217,12 +212,7 @@ def batch_make_video_embeddings(
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
video_grid_thw_on_device = video_grid_thw.to(
visual.device, dtype=torch.int64
)
return visual(
pixel_values_on_device, grid_thw=video_grid_thw_on_device
).cpu()
return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()
video_embeds = torch.concat(llm.apply_model(get_image_embeds))

View File

@@ -29,6 +29,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -751,25 +752,27 @@ class Qwen2VisionTransformer(nn.Module):
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
# compute position embedding
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
# transformers
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
x,

View File

@@ -553,18 +553,20 @@ class Qwen3_VisionTransformer(nn.Module):
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)