diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 5904ad1f1f247..68dd07820189e 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -25,7 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping
-from functools import partial
+from functools import lru_cache, partial
 from typing import Callable, Literal, Optional, TypedDict, Union
 
 import torch
@@ -478,8 +478,8 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
         super().__init__()
         self.dim = dim
         self.theta = theta
-        inv_freq = 1.0 / (theta
-                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        inv_freq = 1.0 / (theta**(
+            torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
         self._freqs_cached = None
@@ -520,7 +520,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
 
-        # args for get_window_index
+        # args for get_window_index_thw
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -567,65 +567,71 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
-        pos_ids = []
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            pos_ids.append(
-                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+    def rotary_pos_emb_thw(self, t, h, w):
+        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+        hpos_ids = hpos_ids.reshape(
+            h // self.spatial_merge_size,
+            self.spatial_merge_size,
+            w // self.spatial_merge_size,
+            self.spatial_merge_size,
+        ).permute(0, 2, 1, 3).flatten()
+        wpos_ids = wpos_ids.reshape(
+            h // self.spatial_merge_size,
+            self.spatial_merge_size,
+            w // self.spatial_merge_size,
+            self.spatial_merge_size,
+        ).permute(0, 2, 1, 3).flatten()
+        pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
+        max_size = max(h, w)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        rotary_pos_emb = rotary_pos_emb.reshape(
+            rotary_pos_emb.shape[0] // self.spatial_merge_unit,
+            self.spatial_merge_unit, -1)
+        return rotary_pos_emb
 
-    def get_window_index(self, grid_thw):
-        window_index: list = []
-        cu_window_seqlens: list = [0]
-        window_index_id = 0
+    def get_window_index_thw(self, grid_t, grid_h, grid_w):
         vit_merger_window_size = (self.window_size //
                                   self.spatial_merge_size // self.patch_size)
-        for grid_t, grid_h, grid_w in grid_thw:
-            llm_grid_h = grid_h // self.spatial_merge_size
-            llm_grid_w = grid_w // self.spatial_merge_size
-            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-                grid_t, llm_grid_h, llm_grid_w)
-            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
-            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
-            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
-            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
-            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
-            index_padded = index_padded.reshape(grid_t, num_windows_h,
-                                                vit_merger_window_size,
-                                                num_windows_w,
-                                                vit_merger_window_size)
-            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
-                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
-                vit_merger_window_size)
-            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
-            index_padded = index_padded.reshape(-1)
-            index_new = index_padded[index_padded != -100]
-            window_index.append(index_new + window_index_id)
-            cu_seqlens_tmp = seqlens.cumsum(
-                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
-            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
-        window_index = torch.cat(window_index, dim=0)
-        return window_index, cu_window_seqlens
+        llm_grid_h = grid_h // self.spatial_merge_size
+        llm_grid_w = grid_w // self.spatial_merge_size
+        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+            grid_t, llm_grid_h, llm_grid_w)
+        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+        index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+        index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                            vit_merger_window_size,
+                                            num_windows_w,
+                                            vit_merger_window_size)
+        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+            grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+            vit_merger_window_size)
+        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+        index_padded = index_padded.reshape(-1)
+        index_new = index_padded[index_padded != -100]
+        cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit
+        cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32)
+        cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp)
+
+        return index_new, cu_seqlens_tmp
+
+    @lru_cache(maxsize=1024)  # noqa: B019
+    def get_rope_by_thw(self, t, h, w):
+        window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
+            t, h, w)
+        rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
+        rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
+        rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
+        cu_seqlens_thw = torch.repeat_interleave(
+            torch.tensor([h * w], dtype=torch.int32), t)
+        return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
+                cu_seqlens_thw)
 
     def compute_attn_mask_seqlen(
         self,
@@ -641,45 +647,74 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
         # patchify
+        seq_len, _ = x.size()
+        rotary_pos_emb = []
+        window_index: list = []
+        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
+        cu_seqlens: list = []
+
         hidden_states = x.to(device=self.device, dtype=self.dtype)
         hidden_states = self.patch_embed(hidden_states)
 
-        # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index_id = 0
+        cu_window_seqlens_last = 0
+        for t, h, w in grid_thw:
+            t, h, w = int(t), int(h), int(w)
+            llm_h = h // self.spatial_merge_size
+            llm_w = w // self.spatial_merge_size
 
-        # windows attention
-        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
-        cu_window_seqlens = torch.tensor(
-            cu_window_seqlens,
-            device=hidden_states.device,
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
+            (
+                rotary_pos_emb_thw,
+                window_index_thw,
+                cu_seqlens_window_thw,
+                cu_seqlens_thw,
+            ) = self.get_rope_by_thw(t, h, w)
+
+            window_index.append(window_index_thw + window_index_id)
+            window_index_id += (t * llm_h * llm_w)
+
+            cu_seqlens_window_thw = (cu_seqlens_window_thw +
+                                     cu_window_seqlens_last)
+            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
+            cu_window_seqlens.append(cu_seqlens_window_thw)
+
+            rotary_pos_emb.append(rotary_pos_emb_thw)
+
+            cu_seqlens.append(cu_seqlens_thw)
+
+        rotary_pos_emb = torch.cat(rotary_pos_emb)
+        window_index = torch.cat(window_index)
+        cu_window_seqlens = torch.cat(cu_window_seqlens)
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
-        seq_len, _ = hidden_states.size()
-        hidden_states = hidden_states.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        hidden_states = hidden_states[window_index, :, :]
-        hidden_states = hidden_states.reshape(seq_len, -1)
-        rotary_pos_emb = rotary_pos_emb.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
-        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                             grid_thw[:, 0]).cumsum(
-                                                 dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(cu_seqlens)
+        cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
         # transformers
-        hidden_states = hidden_states.unsqueeze(1)
-
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
         max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
             cu_seqlens)
         max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
             cu_window_seqlens)
+
+        cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+        cu_window_seqlens = cu_window_seqlens.to(device=self.device,
+                                                 non_blocking=True)
+        rotary_pos_emb = rotary_pos_emb.to(device=self.device,
+                                           non_blocking=True)
+        window_index = window_index.to(device=hidden_states.device,
                                        non_blocking=True)
+
+        hidden_states = hidden_states.reshape(
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+
+        hidden_states = hidden_states.unsqueeze(1)
+
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
@@ -932,12 +967,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
@@ -951,13 +987,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            video_embeds = self.visual(pixel_values_videos,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
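
The core of the change is that all per-(t, h, w) work is memoized with `functools.lru_cache` on `get_rope_by_thw`, so repeated grid shapes reuse the precomputed rotary tables, window indices, and seqlens instead of rebuilding them for every request. Below is a minimal sketch of that memoization pattern, assuming only the caching semantics shown above; the class and the placeholder tensor are illustrative, not code from this patch.

    from functools import lru_cache

    import torch


    class RopeCacheSketch:
        """Illustrative stand-in for the vision transformer's cached helper."""

        @lru_cache(maxsize=1024)  # noqa: B019 - cache keyed on (self, t, h, w), as in the patch
        def get_rope_by_thw(self, t: int, h: int, w: int) -> torch.Tensor:
            # Placeholder for the real work (window index, cu_seqlens, rotary tables).
            return torch.arange(t * h * w)


    sketch = RopeCacheSketch()
    first = sketch.get_rope_by_thw(1, 32, 32)   # computed on the first call
    second = sketch.get_rope_by_thw(1, 32, 32)  # returned straight from the cache
    assert first is second                      # same cached tensor object

This is also why `forward` now takes `grid_thw` as `list[list[int]]` and casts `t, h, w` to `int` inside the loop: plain Python ints are hashable cache keys, whereas per-request tensors would defeat the cache.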