From b2155ed3175f203c9ef7de66636ffbfde1d541b7 Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Fri, 10 Oct 2025 17:42:17 +0100
Subject: [PATCH] [Model][Qwen3VL] Compute `cu_seqlens` on CPU to remove
 (#26496)

Signed-off-by: Lukas Geiger
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm/model_executor/models/qwen3_vl.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f7ba06d97f01..8862e88bd531 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -488,7 +488,9 @@ class Qwen3_VisionTransformer(nn.Module):
 
         indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(dtype=self.dtype, device=self.device)
+        weights = weights.to(
+            dtype=self.dtype, device=self.device, non_blocking=True
+        )
 
         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
@@ -524,14 +526,15 @@ class Qwen3_VisionTransformer(nn.Module):
         x: torch.Tensor,
         grid_thw: list[list[int]],
     ) -> torch.Tensor:
-        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
+        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
@@ -542,8 +545,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
-        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
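
Why this removes a stall: `compute_attn_mask_seqlen` reduces `cu_seqlens` to
host-side scalars for the attention kernels, so when `cu_seqlens` lived on the
GPU that read forced a blocking device-to-host sync on every forward pass.
With the patch, the prefix sums are built on the CPU (where the scalar read is
free) and only the finished tensor is copied to the device with
`non_blocking=True`. A minimal standalone sketch of the pattern; the helper
name `cpu_cu_seqlens` is hypothetical, not vLLM API, and the max-seqlen
derivation here stands in for `compute_attn_mask_seqlen`:

    import torch

    def cpu_cu_seqlens(grid_thw: list[list[int]], device: torch.device):
        """Build cumulative sequence lengths on the CPU, then copy async."""
        # One (t, h, w) row per image/video; a tiny tensor, so host ops are cheap.
        grid = torch.tensor(grid_thw, dtype=torch.int32)
        # Each frame contributes h * w tokens, repeated t times.
        seqlens = torch.repeat_interleave(grid[:, 1] * grid[:, 2], grid[:, 0])
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
        # Reading the max of a CPU tensor never touches the GPU stream.
        max_seqlen = int(seqlens.max())
        # Only the finished prefix sums are shipped to the device, asynchronously.
        return cu_seqlens.to(device, non_blocking=True), max_seqlen

    # Example: one 1x4x4 image and one 2x2x2 video clip.
    cu, mx = cpu_cu_seqlens([[1, 4, 4], [2, 2, 2]], torch.device("cpu"))
    print(cu.tolist(), mx)  # [0, 16, 20, 24] 16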