[Model][Qwen3VL] Compute cu_seqlens on CPU to remove GPU sync (#26496)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Lukas Geiger
Date: 2025-10-10 17:42:17 +01:00 (committed by GitHub)
Parent: 910abdbd08
Commit: b2155ed317
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
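
Summary of the diff below: grid_thw_tensor and the cu_seqlens derived from it now stay on the CPU, max_seqlen/seqlens are computed there, and only afterwards is cu_seqlens copied to the device with non_blocking=True, so the host never blocks on the GPU. A minimal standalone sketch of that pattern, assuming the same repeat_interleave construction as the diff followed by a cumulative sum (the helper name and example values are illustrative, not from the commit):

import torch

def cu_seqlens_on_cpu(grid_thw: list[list[int]], device: torch.device) -> torch.Tensor:
    # Build the tensor on the CPU so the arithmetic below never touches the GPU.
    grid = torch.tensor(grid_thw, dtype=torch.int32)
    # One sequence of length h * w per frame, repeated t times per image/video.
    seqlens = torch.repeat_interleave(grid[:, 1] * grid[:, 2], grid[:, 0])
    cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
    # The host-to-device copy is queued asynchronously; the host does not wait.
    return cu_seqlens.to(device, non_blocking=True)

# e.g. one 4x4 image and a 2-frame 2x2 video -> tensor([0, 16, 20, 24])
print(cu_seqlens_on_cpu([[1, 4, 4], [2, 2, 2]], torch.device("cpu")))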

@@ -488,7 +488,9 @@ class Qwen3_VisionTransformer(nn.Module):
         indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(dtype=self.dtype, device=self.device)
+        weights = weights.to(
+            dtype=self.dtype, device=self.device, non_blocking=True
+        )
 
         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
@@ -524,14 +526,15 @@ class Qwen3_VisionTransformer(nn.Module):
         x: torch.Tensor,
         grid_thw: list[list[int]],
     ) -> torch.Tensor:
-        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
+        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
@@ -542,8 +545,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
-        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
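
The one subtlety is ordering: compute_attn_mask_seqlen is called while cu_seqlens is still on the CPU, and the device copy happens only afterwards. Its body is not part of this diff; a hedged sketch of the kind of host-side scalar extraction that motivates the reordering (illustrative, not vLLM's implementation):

import torch

def compute_attn_mask_seqlen_sketch(cu_seqlens: torch.Tensor) -> tuple[int, torch.Tensor]:
    # Per-sequence lengths from the cumulative boundaries.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # On a CPU tensor this conversion is free; on a GPU tensor the implicit
    # .item() would block the host until all queued kernels finish.
    max_seqlen = int(seqlens.max())
    return max_seqlen, seqlens

Run on the GPU tensor instead, the int() conversion would serialize the host with the device, which is exactly the sync this commit removes.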