From 727144bed10ffd465e37d47a5a60747efc15368b Mon Sep 17 00:00:00 2001 From: dsinghvi Date: Sat, 11 Oct 2025 12:51:04 +0530 Subject: [PATCH] [Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172) Signed-off-by: Divyansh Singhvi Signed-off-by: dsinghvi Signed-off-by: DarkLight1337 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: DarkLight1337 Co-authored-by: wwl2755 --- .../layers/rotary_embedding/mrope.py | 1015 ----------------- vllm/model_executor/models/ernie45_vl.py | 151 ++- vllm/model_executor/models/glm4v.py | 152 ++- vllm/model_executor/models/keye_vl1_5.py | 144 ++- .../models/qwen2_5_omni_thinker.py | 271 ++++- vllm/model_executor/models/qwen2_5_vl.py | 130 ++- vllm/model_executor/models/qwen3_vl.py | 115 +- vllm/model_executor/models/utils.py | 8 + vllm/v1/worker/gpu_model_runner.py | 39 +- 9 files changed, 974 insertions(+), 1051 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index ebfe9257c6c4..fce110e6a527 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools from typing import Optional, Union import numpy as np @@ -411,969 +410,6 @@ class MRotaryEmbedding(RotaryEmbedding): return llm_positions.tolist(), mrope_position_delta - @classmethod - def get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - from vllm.transformers_utils.config import thinker_uses_mrope - - if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni": - return cls._omni_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - elif hf_config.model_type in ["glm4v", "glm4v_moe"]: - return cls._glm4v_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: - return cls._qwen3vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: - return cls._ernie_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif "KeyeVL1_5" in hf_config.model_type: - return cls._keye_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - 
image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - else: - return cls._vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - ) - - @classmethod - def _glm4v_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" - - image_token_id = hf_config.image_token_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1] - ): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - mm_data_idx += 1 - - elif modality_type == "video": - t, h, w = ( - video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - - for t_idx in range(llm_grid_t): - t_index = ( - torch.tensor(t_idx) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(1, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(1, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - 
text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - video_frame_num = 1 - - else: - text_len = len(input_tokens) - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta - - @classmethod - def _qwen3vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - 
llm_positions = llm_positions[:, context_len:seq_len] - return llm_positions, mrope_position_delta - - @classmethod - def _ernie_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for Ernie VL.""" - - image_token_id = hf_config.im_patch_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_conv_size = hf_config.spatial_conv_size - temporal_conv_size = hf_config.temporal_conv_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1] - ): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_conv_size, - w // spatial_conv_size, - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - mm_data_idx += 1 - - elif modality_type == "video": - t, h, w = ( - video_grid_thw[mm_data_idx][0], - video_grid_thw[mm_data_idx][1], - video_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t // temporal_conv_size, - h // spatial_conv_size, - w // spatial_conv_size, - ) - - for t_idx in range(llm_grid_t): - t_index = ( - torch.tensor(t_idx) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(1, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(1, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - video_frame_num = 1 - - else: - text_len = len(input_tokens) - 
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta - - @classmethod - def _keye_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: - video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" - - def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: - """ - Split grid_thw along the t dimension. - - Args: - grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. - - Returns: - List of [1, h, w] rows, repeated t times for each original row. - """ - - if isinstance(grid_thw, list): - grid_thw = torch.tensor(grid_thw, dtype=torch.long) - - if grid_thw.numel() == 0: - return [] - - t, hw = grid_thw[:, 0], grid_thw[:, 1:] - ones = torch.ones_like(hw[:, :1]) # [N,1] - out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) - return out.tolist() - - video_grid_thw = split_thw(video_grid_thw) - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - image_nums = len(image_grid_thw) - frame_nums = len(video_grid_thw) - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_frames = image_nums, frame_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + frame_nums): - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_frames > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_frames -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - t_index = ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) - ).flatten() - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - 
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_videos > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - t_index = ( - torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) - * video_second_per_grid_t - * tokens_per_second - ).flatten() - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - 
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _omni_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: Optional[list[float]] = None, - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - - Example: - - (V_i are vision position ids, A_i are audio position ids) - - |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... - |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... - """ - - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. - thinker_config = hf_config.thinker_config - - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr( - thinker_config.vision_config, "tokens_per_second", 25 - ) - - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] - - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: - if use_audio_in_video and idx > 0: - if ( - src_item[idx] == vision_end_token_id - and src_item[idx - 1] == audio_end_token_id - ): - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif ( - src_item[idx] == audio_start_token_id - and src_item[idx - 1] == vision_start_token_id - ): - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - 
grid_t = image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = torch.arange(grid_t) * 1 * tokens_per_second - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([image_token_id] * vision_seqlen) - image_idx += 1 - elif src_item[idx] == video_token_id and not use_audio_in_video: - grid_t = video_grid_thw[video_idx][0] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ) - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([video_token_id] * vision_seqlen) - video_idx += 1 - else: - # read audio from video - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - grid_t = video_grid_thw[video_idx][0] - grid_h = video_grid_thw[video_idx][1] - grid_w = video_grid_thw[video_idx][2] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ) - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk - ) - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 - pure_audio_len = place_num - 2 - added_audio_len = 0 - audio_llm_pos_ids_list: list[torch.Tensor] = [] - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = ( - len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - ) - new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) - vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( - start_idx, - video_idx, - spatial_merge_size, - t_chunk, - grid_hs, - grid_ws, - ).split(1, dim=1) - llm_pos_ids_list.extend(vision_llm_pos_ids_list) - new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) - * [audio_token_id] - ) - audio_start_idx = ( - start_idx - if len(audio_llm_pos_ids_list) == 0 - else audio_llm_pos_ids_list[-1][0].item() + 1 - ) - if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = ( - torch.arange( - min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - ).expand(3, -1) - + audio_start_idx - ).split(1, dim=1) - else: - audio_llm_pos_ids_list = [] - added_audio_len += min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - if added_audio_len < pure_audio_len: - new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id] - ) - audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand(3, -1) - + llm_pos_ids_list[-1].max() - + 1 - ).split(1, dim=1) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - audio_idx += 1 - video_idx += 1 - # move to the next token - idx += len(new_src_item) - new_src_item_len - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = ( - torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) - ) - llm_positions = llm_positions[:, context_len:seq_len] 
- - return llm_positions, mrope_position_delta - - @staticmethod - def _get_llm_pos_ids_for_vision( - start_idx: int, - vision_idx: int, - spatial_merge_size: int, - t_index: list[int], - grid_hs: torch.Tensor, - grid_ws: torch.Tensor, - ) -> torch.Tensor: - llm_pos_ids_list = [] - llm_grid_h = grid_hs[vision_idx] // spatial_merge_size - llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(len(t_index), -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(len(t_index), llm_grid_h, -1) - .flatten() - ) - t_index_tensor = ( - torch.Tensor(t_index) - .to(llm_grid_h.device) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .long() - .flatten() - ) - _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) - llm_pos_ids_list.append(_llm_pos_ids + start_idx) - llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) - return llm_pos_ids - - @staticmethod - def _split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: - ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] - for num in lst: - index = num // interval - ranges[index].append(num) - return ranges - @staticmethod def get_next_input_positions( mrope_position_delta: int, @@ -1403,54 +439,3 @@ class MRotaryEmbedding(RotaryEmbedding): dtype=out.dtype, ) out[:, out_offset : out_offset + num_new_tokens] = values - - @classmethod - def omni_get_updates_use_audio_in_video( - cls, - thinker_config: PretrainedConfig, - audio_len: int, - video_grid_thw: Union[list[int], torch.Tensor], - video_second_per_grid_t: float, - ) -> list[int]: - """Get video prompt updates when `use_audio_in_video` is True. - - In this case, audio and vision update ids will be split into - chunks and interleaved (details in `_omni_get_input_positions_tensor`). - - <|video_bos|><|VIDEO|><|video_eos|> => - <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> - """ - - audio_token_id = thinker_config.audio_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr( - thinker_config.vision_config, "tokens_per_second", 25 - ) - - grid_t = video_grid_thw[0] - grid_h = video_grid_thw[1] - grid_w = video_grid_thw[2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second - t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) - - updates = [audio_start_token_id] - added_audio_len = 0 - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = ( - len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - ) - updates.extend([video_token_id] * vision_ntoken_per_chunk) - - audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) - updates.extend(audio_chunk_size * [audio_token_id]) - added_audio_len += audio_chunk_size - if added_audio_len < audio_len: - updates.extend((audio_len - added_audio_len) * [audio_token_id]) - updates.extend([audio_end_token_id]) - - return updates diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 2579a0ebf53e..d5b2caa2ddfd 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only Erine VL model compatible with HuggingFace weights.""" +import itertools import math from collections.abc import Iterable, Mapping, Sequence from functools import partial @@ -33,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from vllm.attention.backends.registry import _Backend from vllm.attention.layer import ( @@ -76,6 +77,7 @@ from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -1271,7 +1273,7 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing dummy_inputs=Ernie4_5_VLDummyInputsBuilder, ) class Ernie4_5_VLMoeForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): merge_by_field_config = True @@ -1388,6 +1390,151 @@ class Ernie4_5_VLMoeForConditionalGeneration( else: self.visual_token_mask = None + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = 
hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t // temporal_conv_size, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.language_model diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index a5c3ce0e6bf7..63731b2947d2 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -5,6 +5,7 @@ # https://github.com/zai-org/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" 
+import itertools from argparse import Namespace from collections.abc import Mapping, Sequence from typing import Annotated, Literal, Optional, Union @@ -14,7 +15,7 @@ from torch import nn from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import BatchFeature, PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -54,6 +55,7 @@ from .chatglm import ChatGLMBaseModel, ChatGLMModel from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -554,7 +556,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): info=GLM4VProcessingInfo, dummy_inputs=GLM4VDummyInputsBuilder, ) -class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP): +class GLM4VForCausalLM( + ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): merge_by_field_config = True packed_modules_mapping = { @@ -615,6 +619,150 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, Suppo return self.transformer.vision(pixel_values) + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) 
+ .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.transformer diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 578436fcad21..21d8099b43d1 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -38,7 +38,7 @@ from vllm.multimodal.processing import ( ) from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP from .keye import ( BaseKeyeModule, BaseMultiModalProcessor, @@ -493,7 +493,7 @@ class KeyeVL1_5DummyInputsBuilder( dummy_inputs=KeyeVL1_5DummyInputsBuilder, ) class KeyeVL1_5ForConditionalGeneration( - BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): def _build_projector( self, @@ -589,3 +589,143 @@ class KeyeVL1_5ForConditionalGeneration( end = patch_cu_seqlens[idx + 1] new_video_embeds.append(video_embeds[start:end]) return tuple(new_video_embeds) + + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: + video_grid_thw = video_grid_thw[0] + """Get mrope input positions and delta value (Keye series).""" + + def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: + """ + Split grid_thw along the t dimension. + + Args: + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. 
+ + Returns: + List of [1, h, w] rows, repeated t times for each original row. + """ + + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.long) + + if grid_thw.numel() == 0: + return [] + + t, hw = grid_thw[:, 0], grid_thw[:, 1:] + ones = torch.ones_like(hw[:, :1]) # [N,1] + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) + return out.tolist() + + video_grid_thw = split_thw(video_grid_thw) + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_nums = len(image_grid_thw) + frame_nums = len(video_grid_thw) + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_frames = image_nums, frame_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + frame_nums): + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_frames > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_frames -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 1ab2f43c9d73..0df79fc733f3 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -29,6 +29,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn +from transformers import PretrainedConfig from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( Qwen2_5OmniConfig, @@ -45,7 +46,6 @@ from transformers.models.whisper import 
WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, @@ -93,6 +93,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -101,7 +102,9 @@ from .utils import ( WeightsMapper, init_vllm_registered_model, maybe_prefix, + split_list_into_ranges, ) +from .vision import get_llm_pos_ids_for_vision try: import flash_attn @@ -412,6 +415,59 @@ class Qwen2_5OmniThinkerMultiModalProcessor( return prompt_ids, mm_placeholders + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -491,7 +547,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( else: video_second_per_grid_t = 1.0 - return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + return self.omni_get_updates_use_audio_in_video( thinker_config=thinker_config, audio_len=audio_num_features, video_grid_thw=video_grid_thw, @@ -808,6 +864,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( SupportsMultiModal, SupportsPP, SupportsLoRA, + SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, ): hf_to_vllm_mapper = WeightsMapper( @@ -929,6 +986,216 @@ class Qwen2_5OmniThinkerForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: 
list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if ( + src_item[idx] == vision_end_token_id + and src_item[idx - 1] == audio_end_token_id + ): + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif ( + src_item[idx] == audio_start_token_id + and src_item[idx - 1] == vision_start_token_id + ): + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, 
t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges( + t_index, t_ntoken_per_chunk + ) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + * [audio_token_id] + ) + audio_start_idx = ( + start_idx + if len(audio_llm_pos_ids_list) == 0 + else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange( + min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + ).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id] + ) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + + llm_pos_ids_list[-1].max() + + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = ( + torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + ) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: diff --git 
a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 9cd83f61d921..094fd90aac4e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -34,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -79,6 +79,7 @@ from .interfaces import ( MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsMultiModalPruning, SupportsPP, @@ -1053,6 +1054,7 @@ class Qwen2_5_VLForConditionalGeneration( SupportsQuant, SupportsEagle3, SupportsMultiModalPruning, + SupportsMRoPE, ): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -1073,6 +1075,132 @@ class Qwen2_5_VLForConditionalGeneration( supports_encoder_tp_data = True + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) 
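# The block below builds the three M-RoPE axes (temporal, height, width)
# for the vision segment starting at `ed`, after the flat text positions
# appended just above. Worked example -- annotation only, not part of the
# patch: one image with grid (t, h, w) = (1, 4, 4) and
# spatial_merge_size = 2 gives llm_grid_t/h/w = 1, 2, 2, so before the
# `text_len + st_idx` offset is added:
#   t_index = [0, 0, 0, 0]   (images keep video_second_per_grid_t = 0.0)
#   h_index = [0, 0, 1, 1]
#   w_index = [0, 1, 0, 1]
# i.e. a single temporal step over a 2x2 merged spatial raster, matching
# the centralized M-RoPE implementation this patch removes from mrope.py.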
+ + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 8862e88bd531..1e6c3485c4d6 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -33,7 +33,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( smart_resize as image_smart_resize, @@ -84,6 +84,7 @@ from vllm.utils import is_list_of from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -1174,7 +1175,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM): dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3VLForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): packed_modules_mapping = { "qkv_proj": [ @@ -1480,6 +1481,116 @@ class Qwen3VLForConditionalGeneration( ) return mm_input_by_modality + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = 
image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.language_model diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 2a64f6865f12..bd530be73c2a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -410,6 +410,14 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: return " + ".join(_embedding_count_expression(inner) for inner in embeddings) +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a323835e575c..ec824f6d6bf5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -875,30 +875,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if mm_input.get("use_audio_in_video") is True: use_audio_in_video = True - if supports_mrope(self.get_model()): - req_state.mrope_positions, req_state.mrope_position_delta = ( - self.model.get_mrope_input_positions( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - 
second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - ) - else: - req_state.mrope_positions, req_state.mrope_position_delta = ( - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + + req_state.mrope_positions, req_state.mrope_position_delta = ( + self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -2900,7 +2889,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if not supports_eagle3(self.model): + if not supports_eagle3(self.get_model()): raise RuntimeError( "Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested" @@ -2928,7 +2917,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): prepare_communication_buffer_for_model(self.model) self.is_multimodal_pruning_enabled = ( - supports_multimodal_pruning(self.model) + supports_multimodal_pruning(self.get_model()) and self.model_config.multimodal_config.is_multimodal_pruning_enabled() )
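Taken together, the per-model get_mrope_input_positions methods above all follow the contract the runner now assumes: a (3, seq_len) position tensor plus an integer position delta. A hedged, self-contained sketch of that contract follows; ToyMRoPEModel and its text-only behaviour are illustrative stand-ins, not vLLM code.

from typing import Optional, Union
import torch

class ToyMRoPEModel:
    @classmethod
    def get_mrope_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: object,  # stand-in for transformers.PretrainedConfig
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        # Text-only case: all three axes are simply 0..N-1,
        # so the delta (max position + 1 - prompt length) is 0.
        n = len(input_tokens)
        positions = torch.arange(n).view(1, -1).expand(3, -1)
        mrope_position_delta = int(positions.max().item()) + 1 - n
        return positions[:, context_len:seq_len], mrope_position_delta

positions, delta = ToyMRoPEModel.get_mrope_input_positions(
    input_tokens=[1, 2, 3, 4],
    hf_config=None,
    image_grid_thw=[],
    video_grid_thw=[],
)
print(positions.shape, delta)  # torch.Size([3, 4]) 0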