[Model][Qwen3VL] Compute cu_seqlens on CPU to remove (#26496)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

parent 910abdbd08
commit b2155ed317
@@ -488,7 +488,9 @@ class Qwen3_VisionTransformer(nn.Module):
         indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(dtype=self.dtype, device=self.device)
+        weights = weights.to(
+            dtype=self.dtype, device=self.device, non_blocking=True
+        )

         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
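Note (illustrative, not part of the patch): non_blocking=True makes the .to() copy asynchronous with respect to the host, so the CPU can keep enqueueing work instead of stalling on the transfer; the copy only truly overlaps when the source tensor sits in pinned memory, and ordering on the destination CUDA stream is still preserved. A minimal sketch of the pattern:

import torch

# Illustrative only: an asynchronous host-to-device copy needs a pinned
# (page-locked) source tensor. With non_blocking=True, .to() enqueues the
# copy and returns immediately so the CPU can keep issuing work; with a
# pageable source the copy silently degrades to a synchronous one.
if torch.cuda.is_available():
    src = torch.randn(1024, 1024, pin_memory=True)
    dst = src.to(device="cuda", dtype=torch.float16, non_blocking=True)
    # Kernels launched afterwards on the same stream still see the finished
    # copy; only a host-side read forces a wait.
    print(dst.sum().item())  # .item() synchronizes to fetch the value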
@@ -524,14 +526,15 @@ class Qwen3_VisionTransformer(nn.Module):
         x: torch.Tensor,
         grid_thw: list[list[int]],
     ) -> torch.Tensor:
-        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)

         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)

-        grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
+        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)

         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
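This is the core of the change: grid_thw_tensor is now built without device=self.device, so repeat_interleave and the cumulative sum run on CPU, where the lengths are already known host-side. A standalone sketch of that computation (the example grid values are made up, and the .cumsum(...) step mirrors code between the hunks rather than anything shown in this diff, so treat it as an assumption):

import torch

# Standalone sketch: each (t, h, w) entry contributes t frames of h*w
# patches, so the per-frame sequence lengths are h*w repeated t times, and
# cu_seqlens is their cumulative sum with a leading zero.
grid_thw = [[2, 4, 6], [1, 8, 8]]  # two inputs: (t, h, w), chosen arbitrarily

grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)  # CPU, no device=
cu_seqlens = torch.repeat_interleave(
    grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

print(cu_seqlens)  # tensor([  0,  24,  48, 112], dtype=torch.int32)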
@@ -542,8 +545,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

         hidden_states = hidden_states.unsqueeze(1)
-        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
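Why this ordering matters: compute_attn_mask_seqlen has to read concrete lengths out of cu_seqlens, and reading values from a CUDA tensor (via .item(), .tolist(), and similar) forces the host to wait for the GPU. With cu_seqlens still on CPU those reads are free, and only afterwards is the tensor shipped to the device. A sketch of that pattern, using a hypothetical helper that stands in for compute_attn_mask_seqlen (the real method's signature and return types are not shown in this diff):

import torch

# attn_mask_seqlen is a made-up stand-in: it reads per-sequence lengths out
# of cu_seqlens, which is a blocking device sync if the tensor lives on GPU
# but free on CPU.
def attn_mask_seqlen(cu_seqlens: torch.Tensor) -> tuple[int, list[int]]:
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # host-side read
    return max(seqlens), seqlens

cu_seqlens = torch.tensor([0, 24, 48, 112], dtype=torch.int32)  # CPU tensor
max_seqlen, seqlens = attn_mask_seqlen(cu_seqlens)  # no GPU sync involved
print(max_seqlen, seqlens)  # 64 [24, 24, 64]

# Only after the host-side bookkeeping does the tensor move to the device:
if torch.cuda.is_available():
    cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)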