diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index e10b8e1e77af1..e1b7dbf99f1fd 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -128,12 +128,7 @@ def batch_make_image_embeddings( visual = model.visual pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) - image_grid_thw_on_device = image_grid_thw.to( - visual.device, dtype=torch.int64 - ) - return visual( - pixel_values_on_device, grid_thw=image_grid_thw_on_device - ).cpu() + return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu() image_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -217,12 +212,7 @@ def batch_make_video_embeddings( visual = model.visual pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) - video_grid_thw_on_device = video_grid_thw.to( - visual.device, dtype=torch.int64 - ) - return visual( - pixel_values_on_device, grid_thw=video_grid_thw_on_device - ).cpu() + return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu() video_embeds = torch.concat(llm.apply_model(get_image_embeds)) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d25ff2785bfef..479a7871e364f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -29,6 +29,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -751,25 +752,27 @@ class Qwen2VisionTransformer(nn.Module): if isinstance(grid_thw, list): grid_thw_list = grid_thw - grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + grid_thw = np.array(grid_thw, dtype=np.int32) else: grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() # compute position embedding rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=torch.int32) - cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) - cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + axis=0, dtype=np.int32 + ) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + cu_seqlens = torch.from_numpy(cu_seqlens) # transformers x = x.unsqueeze(1) # pre-compute seqlens for attn mask to reduce cuMemcpy operations max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) for blk in self.blocks: x = blk( x, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index c10aeaec5ab83..90c4894d33e88 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -553,18 +553,20 @@ class Qwen3_VisionTransformer(nn.Module): if isinstance(grid_thw, list): grid_thw_list = grid_thw - grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + grid_thw = np.array(grid_thw, dtype=np.int32) else: grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) - cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + axis=0, dtype=np.int32 + ) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + cu_seqlens = torch.from_numpy(cu_seqlens) hidden_states = hidden_states.unsqueeze(1) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)