mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 02:17:02 +08:00
[Model][QwenVL] Replace torch.repeat_interleave with faster np.repeat (#28964)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
64192d5624
commit
a9705a290a
@ -128,12 +128,7 @@ def batch_make_image_embeddings(
|
||||
visual = model.visual
|
||||
|
||||
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
||||
image_grid_thw_on_device = image_grid_thw.to(
|
||||
visual.device, dtype=torch.int64
|
||||
)
|
||||
return visual(
|
||||
pixel_values_on_device, grid_thw=image_grid_thw_on_device
|
||||
).cpu()
|
||||
return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()
|
||||
|
||||
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||
|
||||
@ -217,12 +212,7 @@ def batch_make_video_embeddings(
|
||||
visual = model.visual
|
||||
|
||||
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
||||
video_grid_thw_on_device = video_grid_thw.to(
|
||||
visual.device, dtype=torch.int64
|
||||
)
|
||||
return visual(
|
||||
pixel_values_on_device, grid_thw=video_grid_thw_on_device
|
||||
).cpu()
|
||||
return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()
|
||||
|
||||
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -751,25 +752,27 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_thw = grid_thw.numpy()
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
axis=0, dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
for blk in self.blocks:
|
||||
x = blk(
|
||||
x,
|
||||
|
||||
@ -553,18 +553,20 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_thw = grid_thw.numpy()
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
axis=0, dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user