diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index 0a13543c82e1..ebfe9257c6c4 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -426,7 +426,7 @@ class MRotaryEmbedding(RotaryEmbedding):
     ) -> tuple[torch.Tensor, int]:
         from vllm.transformers_utils.config import thinker_uses_mrope
 
-        if thinker_uses_mrope(hf_config):
+        if thinker_uses_mrope(hf_config) and hf_config.model_type == "qwen2_5_omni":
             return cls._omni_get_input_positions_tensor(
                 input_tokens=input_tokens,
                 hf_config=hf_config,
@@ -1119,339 +1119,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return llm_positions, mrope_position_delta
 
-    @classmethod
-    def _omni3_get_input_positions_tensor(
-        cls,
-        config,
-        input_ids: torch.Tensor,
-        image_grid_thw: torch.Tensor,
-        video_grid_thw: torch.Tensor,
-        use_audio_in_video: bool = False,
-        audio_seqlens: Optional[torch.Tensor] = None,
-        second_per_grids: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
-            input_lengths_leave = input_lengths % 100
-            feat_lengths = (input_lengths_leave - 1) // 2 + 1
-            output_lengths = (
-                ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-            )
-            return output_lengths
-
-        if input_ids is None or input_ids.ndim != 1:
-            raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
-
-        seq_len = input_ids.shape[0]
-        device = input_ids.device
-        dtype = input_ids.dtype
-
-        if image_grid_thw is not None:
-            image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)
-        if video_grid_thw is not None:
-            video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long)
-
-        if second_per_grids is None:
-            if video_grid_thw is not None and video_grid_thw.numel() > 0:
-                second_per_grids = torch.ones(
-                    video_grid_thw.shape[0], dtype=torch.float32, device=device
-                )
-            else:
-                second_per_grids = torch.tensor([], dtype=torch.float32, device=device)
-        else:
-            second_per_grids = second_per_grids.to(device=device, dtype=torch.float32)
-
-        if audio_seqlens is not None:
-            audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long)
-
-        spatial_merge_size = config.vision_config.spatial_merge_size
-        image_token_id = config.image_token_id
-        video_token_id = config.video_token_id
-        audio_token_id = config.audio_token_id
-        vision_start_token_id = config.vision_start_token_id
-        audio_start_token_id = config.audio_start_token_id
-        position_id_per_seconds = config.position_id_per_seconds
-
-        vision_start_indices = torch.argwhere(
-            input_ids == vision_start_token_id
-        ).squeeze(1)
-        if vision_start_indices.numel() > 0:
-            vision_tokens = input_ids[vision_start_indices + 1]
-        else:
-            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
-        audio_nums = torch.sum(input_ids == audio_start_token_id)
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (
-            (vision_tokens == audio_start_token_id).sum()
-            if use_audio_in_video
-            else (vision_tokens == video_token_id).sum()
-        )
-
-        input_tokens = input_ids.tolist()
-        llm_pos_ids_list: list[torch.Tensor] = []
-        st = 0
-        image_idx = 0
-        video_idx = 0
-        audio_idx = 0
-        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
-        multimodal_nums = (
-            image_nums + audio_nums
-            if use_audio_in_video
-            else image_nums + video_nums + audio_nums
-        )  # noqa: E501
-
-        for _ in range(multimodal_nums):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
-                remain_videos > 0 or remain_images > 0
-            ):
-                ed_vision_start = input_tokens.index(vision_start_token_id, st)
-            else:
-                ed_vision_start = len(input_tokens) + 1
-            if audio_token_id in input_tokens and remain_audios > 0:
-                ed_audio_start = input_tokens.index(audio_start_token_id, st)
-            else:
-                ed_audio_start = len(input_tokens) + 1
-            min_ed = min(ed_vision_start, ed_audio_start)
-
-            if min_ed == ed_audio_start:
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
-                llm_pos_ids = (
-                    torch.arange(audio_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + audio_len + eos_len
-                audio_idx += 1
-                remain_audios -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == image_token_id
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = image_grid_thw[image_idx][0]
-                grid_hs = image_grid_thw[:, 1]
-                grid_ws = image_grid_thw[:, 2]
-                t_index = torch.arange(grid_t, device=device) * position_id_per_seconds
-                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + image_len + eos_len
-                image_idx += 1
-                remain_images -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == video_token_id
-                and not use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t, device=device)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + video_len + eos_len
-                video_idx += 1
-                remain_videos -= 1
-            elif (
-                min_ed == ed_vision_start
-                and ed_vision_start + 1 == ed_audio_start
-                and use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, device=device, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                bos_block = (
-                    torch.arange(bos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(bos_block)
-                llm_pos_ids_list.append(bos_block)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
-                audio_llm_pos_ids = (
-                    torch.arange(audio_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t, device=device)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_data_index, audio_data_index = 0, 0
-                while (
-                    video_data_index < video_llm_pos_ids.shape[-1]
-                    and audio_data_index < audio_llm_pos_ids.shape[-1]
-                ):
-                    if (
-                        video_llm_pos_ids[0][video_data_index]
-                        <= audio_llm_pos_ids[0][audio_data_index]
-                    ):
-                        llm_pos_ids_list.append(
-                            video_llm_pos_ids[
-                                :, video_data_index : video_data_index + 1
-                            ]
-                        )
-                        video_data_index += 1
-                    else:
-                        llm_pos_ids_list.append(
-                            audio_llm_pos_ids[
-                                :, audio_data_index : audio_data_index + 1
-                            ]
-                        )
-                        audio_data_index += 1
-                if video_data_index < video_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        video_llm_pos_ids[
-                            :, video_data_index : video_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                if audio_data_index < audio_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        audio_llm_pos_ids[
-                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                eos_block = (
-                    torch.arange(eos_len, device=device, dtype=torch.long)
-                    .view(1, -1)
-                    .expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(eos_block)
-                llm_pos_ids_list.append(eos_block)
-                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
-                audio_idx += 1
-                video_idx += 1
-                remain_videos -= 1
-                remain_audios -= 1
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len, device=device, dtype=torch.long)
-                .view(1, -1)
-                .expand(3, -1)
-                + st_idx
-            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        if llm_positions.shape[1] != seq_len:
-            raise RuntimeError("Position ids length mismatch with input ids length")
-
-        position_ids = llm_positions.to(device=device, dtype=dtype)
-        mrope_position_delta = llm_positions.max() + 1 - seq_len
-        return position_ids, mrope_position_delta
-
     @classmethod
     def _omni_get_input_positions_tensor(
         cls,
@@ -1483,8 +1150,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         # TODO(fyabc): refactor and share more code with
         # _vl_get_input_positions_tensor.
-
-        model_type = hf_config.model_type
         thinker_config = hf_config.thinker_config
 
         if isinstance(image_grid_thw, list):
@@ -1492,30 +1157,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if isinstance(video_grid_thw, list):
             video_grid_thw = torch.tensor(video_grid_thw)
 
-        if "qwen3_omni" in model_type:
-            input_tensor = torch.tensor(input_tokens)
-            audio_lengths_tensor = audio_feature_lengths
-            if audio_lengths_tensor is not None and not isinstance(
-                audio_lengths_tensor, torch.Tensor
-            ):
-                audio_lengths_tensor = torch.as_tensor(
-                    audio_lengths_tensor, dtype=torch.long
-                )
-            second_per_grids_tensor = (
-                torch.tensor(second_per_grid_ts) if second_per_grid_ts else None
-            )
-
-            llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(  # noqa: E501
-                thinker_config,
-                input_tensor,
-                image_grid_thw,
-                video_grid_thw,
-                use_audio_in_video,
-                audio_lengths_tensor,
-                second_per_grids_tensor,
-            )
-            return llm_positions, mrope_position_delta
-
         audio_token_id = thinker_config.audio_token_index
         image_token_id = thinker_config.image_token_index
         video_token_id = thinker_config.video_token_index
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 8a5aa9c2be3b..6eb9faabd1c7 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -72,7 +72,12 @@ from vllm.multimodal.processing import (
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsMRoPE,
+    SupportsMultiModal,
+    SupportsPP,
+)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -96,7 +101,7 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend
+from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
 
 try:
     import flash_attn
@@ -106,6 +111,15 @@ except (ImportError, ModuleNotFoundError):
 
 logger = init_logger(__name__)
 
+
+def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
+    input_lengths_leave = input_lengths % 100
+    feat_lengths = (input_lengths_leave - 1) // 2 + 1
+    output_lengths = (
+        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+    )
+    return feat_lengths, output_lengths
+
 
 class Qwen3_VisionPatchEmbed(nn.Module):
     def __init__(
         self,
@@ -679,16 +693,6 @@ Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
 class Qwen3OmniMoeThinkerMultiModalProcessor(
     Qwen2_5OmniThinkerMultiModalProcessor,
 ):
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: torch.Tensor
-    ) -> torch.Tensor:
-        input_lengths_leave = input_lengths % 100
-        feat_lengths = (input_lengths_leave - 1) // 2 + 1
-        output_lengths = (
-            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-        )
-        return feat_lengths, output_lengths
-
     def _call_hf_processor(
         self,
         prompt: str,
@@ -882,13 +886,13 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
         if audio_feature_lengths is None and feature_attention_mask is None:
             audio_output_lengths = []
         elif audio_feature_lengths is not None:
-            _, audio_output_lens = self._get_feat_extract_output_lengths(
+            _, audio_output_lens = _get_feat_extract_output_lengths(
                 audio_feature_lengths
             )
             audio_output_lengths = audio_output_lens.tolist()
         elif feature_attention_mask is not None:
             assert isinstance(feature_attention_mask, torch.Tensor)
-            _, audio_output_lens = self._get_feat_extract_output_lengths(
+            _, audio_output_lens = _get_feat_extract_output_lengths(
                 feature_attention_mask.sum(-1)
             )
             audio_output_lengths = audio_output_lens.tolist()
@@ -1044,16 +1048,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
         else:
             return torch.concat(mm_input, dim=dim)
 
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: torch.Tensor
-    ) -> torch.Tensor:
-        input_lengths_leave = input_lengths % 100
-        feat_lengths = (input_lengths_leave - 1) // 2 + 1
-        output_lengths = (
-            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
-        )
-        return output_lengths, output_lengths
-
     def _process_audio_input(
         self,
         audio_input: Qwen2AudioFeatureInputs,
@@ -1072,8 +1066,8 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
         if audio_feature_lengths.ndim == 2:
             audio_feature_lengths = audio_feature_lengths.reshape(-1)
 
-        audio_feat_lengths, audio_output_lengths = (
-            self._get_feat_extract_output_lengths(audio_feature_lengths)
+        audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
+            audio_feature_lengths
         )
 
         audio_outputs = self.audio_tower(
@@ -1094,6 +1088,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     nn.Module,
     SupportsMultiModal,
     SupportsPP,
+    SupportsMRoPE,
     Qwen3OmniMoeConditionalGenerationMixin,
 ):
     hf_to_vllm_mapper = WeightsMapper(
@@ -1407,3 +1402,311 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
 
         loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
         return loaded_weights
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute 3-row (t/h/w) M-RoPE positions for a Qwen3-Omni prompt."""
+        config = hf_config.thinker_config
+        if isinstance(image_grid_thw, list):
+            image_grid_thw = torch.tensor(image_grid_thw)
+        if isinstance(video_grid_thw, list):
+            video_grid_thw = torch.tensor(video_grid_thw)
+        input_ids = torch.tensor(input_tokens)
+        if input_ids.ndim != 1:
+            raise ValueError("get_mrope_input_positions expects 1D input_ids")
+
+        seq_len = input_ids.shape[0]
+        if audio_feature_lengths is not None and not isinstance(
+            audio_feature_lengths, torch.Tensor
+        ):
+            audio_feature_lengths = torch.as_tensor(
+                audio_feature_lengths, dtype=torch.long
+            )
+        if second_per_grid_ts is None:
+            if video_grid_thw is not None and video_grid_thw.numel() > 0:
+                second_per_grids = torch.ones(
+                    video_grid_thw.shape[0], dtype=torch.float32
+                )
+            else:
+                second_per_grids = torch.tensor([], dtype=torch.float32)
+        else:
+            second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
+
+        spatial_merge_size = config.vision_config.spatial_merge_size
+        image_token_id = config.image_token_id
+        video_token_id = config.video_token_id
+        audio_token_id = config.audio_token_id
+        vision_start_token_id = config.vision_start_token_id
+        audio_start_token_id = config.audio_start_token_id
+        position_id_per_seconds = config.position_id_per_seconds
+
+        vision_start_indices = torch.argwhere(
+            input_ids == vision_start_token_id
+        ).squeeze(1)
+        if vision_start_indices.numel() > 0:
+            vision_tokens = input_ids[vision_start_indices + 1]
+        else:
+            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
+        audio_nums = torch.sum(input_ids == audio_start_token_id)
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (
+            (vision_tokens == audio_start_token_id).sum()
+            if use_audio_in_video
+            else (vision_tokens == video_token_id).sum()
+        )
+
+        llm_pos_ids_list: list[torch.Tensor] = []
+        st = 0
+        image_idx = 0
+        video_idx = 0
+        audio_idx = 0
+        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
+        multimodal_nums = (
+            image_nums + audio_nums
+            if use_audio_in_video
+            else image_nums + video_nums + audio_nums
+        )  # noqa: E501
+
+        for _ in range(multimodal_nums):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
+                remain_videos > 0 or remain_images > 0
+            ):
+                ed_vision_start = input_tokens.index(vision_start_token_id, st)
+            else:
+                ed_vision_start = len(input_tokens) + 1
+            if audio_token_id in input_tokens and remain_audios > 0:
+                ed_audio_start = input_tokens.index(audio_start_token_id, st)
+            else:
+                ed_audio_start = len(input_tokens) + 1
+            min_ed = min(ed_vision_start, ed_audio_start)
+
+            if min_ed == ed_audio_start:
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                _, audio_len = _get_feat_extract_output_lengths(
+                    audio_feature_lengths[audio_idx]
+                )
+                llm_pos_ids = (
+                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + audio_len + eos_len
+                audio_idx += 1
+                remain_audios -= 1
+            elif (
+                min_ed == ed_vision_start
+                and input_ids[ed_vision_start + 1] == image_token_id
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                grid_t = image_grid_thw[image_idx][0]
+                grid_hs = image_grid_thw[:, 1]
+                grid_ws = image_grid_thw[:, 2]
+                t_index = torch.arange(grid_t) * position_id_per_seconds
+                llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + image_len + eos_len
+                image_idx += 1
+                remain_images -= 1
+            elif (
+                min_ed == ed_vision_start
+                and input_ids[ed_vision_start + 1] == video_token_id
+                and not use_audio_in_video
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                grid_t = video_grid_thw[video_idx][0]
+                grid_hs = video_grid_thw[:, 1]
+                grid_ws = video_grid_thw[:, 2]
+                t_index = (
+                    torch.arange(grid_t)
+                    * float(second_per_grids[video_idx].item())
+                    * position_id_per_seconds
+                )
+                llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+                llm_pos_ids_list.append(llm_pos_ids)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                llm_pos_ids_list.append(
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                st += text_len + bos_len + video_len + eos_len
+                video_idx += 1
+                remain_videos -= 1
+            elif (
+                min_ed == ed_vision_start
+                and ed_vision_start + 1 == ed_audio_start
+                and use_audio_in_video
+            ):
+                text_len = min_ed - st
+                if text_len != 0:
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len, dtype=torch.long)
+                        .view(1, -1)
+                        .expand(3, -1)
+                        + st_idx
+                    )
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                bos_len = 1
+                bos_block = (
+                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(bos_block)
+                llm_pos_ids_list.append(bos_block)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                _, audio_len = _get_feat_extract_output_lengths(
+                    audio_feature_lengths[audio_idx]
+                )
+                audio_llm_pos_ids = (
+                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                grid_t = video_grid_thw[video_idx][0]
+                grid_hs = video_grid_thw[:, 1]
+                grid_ws = video_grid_thw[:, 2]
+                t_index = (
+                    torch.arange(grid_t)
+                    * float(second_per_grids[video_idx].item())
+                    * position_id_per_seconds
+                )
+                video_llm_pos_ids = get_llm_pos_ids_for_vision(
+                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
+                )
+                video_data_index, audio_data_index = 0, 0
+                while (
+                    video_data_index < video_llm_pos_ids.shape[-1]
+                    and audio_data_index < audio_llm_pos_ids.shape[-1]
+                ):
+                    if (
+                        video_llm_pos_ids[0][video_data_index]
+                        <= audio_llm_pos_ids[0][audio_data_index]
+                    ):
+                        llm_pos_ids_list.append(
+                            video_llm_pos_ids[
+                                :, video_data_index : video_data_index + 1
+                            ]
+                        )
+                        video_data_index += 1
+                    else:
+                        llm_pos_ids_list.append(
+                            audio_llm_pos_ids[
+                                :, audio_data_index : audio_data_index + 1
+                            ]
+                        )
+                        audio_data_index += 1
+                if video_data_index < video_llm_pos_ids.shape[-1]:
+                    llm_pos_ids_list.append(
+                        video_llm_pos_ids[
+                            :, video_data_index : video_llm_pos_ids.shape[-1]
+                        ]
+                    )
+                if audio_data_index < audio_llm_pos_ids.shape[-1]:
+                    llm_pos_ids_list.append(
+                        audio_llm_pos_ids[
+                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
+                        ]
+                    )
+                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
+                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                eos_len = 1
+                eos_block = (
+                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                    + st_idx
+                )
+                llm_pos_ids_list.append(eos_block)
+                llm_pos_ids_list.append(eos_block)
+                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
+                audio_idx += 1
+                video_idx += 1
+                remain_videos -= 1
+                remain_audios -= 1
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
+                + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        if llm_positions.shape[1] != seq_len:
+            raise RuntimeError("Position ids length mismatch with input ids length")
+
+        mrope_position_delta = llm_positions.max() + 1 - seq_len
+        return llm_positions, mrope_position_delta
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 74262f8b94a6..e517109e94dd 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -499,3 +499,40 @@ def run_dp_sharded_mrope_vision_model(
         "Found unassigned embeddings"
     )
     return out_embeddings
+
+
+def get_llm_pos_ids_for_vision(
+    start_idx: int,
+    vision_idx: int,
+    spatial_merge_size: int,
+    t_index: list[int],
+    grid_hs: torch.Tensor,
+    grid_ws: torch.Tensor,
+) -> torch.Tensor:
+    llm_pos_ids_list = []
+    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+    h_index = (
+        torch.arange(llm_grid_h)
+        .view(1, -1, 1)
+        .expand(len(t_index), -1, llm_grid_w)
+        .flatten()
+    )
+    w_index = (
+        torch.arange(llm_grid_w)
+        .view(1, 1, -1)
+        .expand(len(t_index), llm_grid_h, -1)
+        .flatten()
+    )
+    t_index_tensor = (
+        torch.Tensor(t_index)
+        .to(llm_grid_h.device)
+        .view(-1, 1)
+        .expand(-1, llm_grid_h * llm_grid_w)
+        .long()
+        .flatten()
+    )
+    _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
+    llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+    llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+    return llm_pos_ids
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2dce58237c7b..a323835e575c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -875,7 +875,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if mm_input.get("use_audio_in_video") is True:
                 use_audio_in_video = True
 
-        if supports_mrope(self.model):
+        if supports_mrope(self.get_model()):
            req_state.mrope_positions, req_state.mrope_position_delta = (
                self.model.get_mrope_input_positions(
                    req_state.prompt_token_ids,
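
For reference, a minimal standalone sketch (not part of the patch) of what the shared _get_feat_extract_output_lengths helper computes: mel frames are consumed in chunks of 100, each full chunk contributes exactly 13 output embeddings, and the remainder passes through three ceil-halving stages. The frame counts below are made up for illustration:

import torch


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    # Same arithmetic as the module-level helper added to
    # qwen3_omni_moe_thinker.py in this patch.
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return feat_lengths, output_lengths


frames = torch.tensor([100, 250, 999])  # made-up mel-frame counts
_, out = _get_feat_extract_output_lengths(frames)
print(out.tolist())  # [13, 33, 130]; e.g. 999 = 9 full chunks (117) + 99 leftover (13)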
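
Likewise, a toy usage sketch for the new get_llm_pos_ids_for_vision helper in vision.py. It assumes this patch is applied; the grid shape, merge size, and position_id_per_seconds values are illustrative only, not taken from any real model config:

import torch

from vllm.model_executor.models.vision import get_llm_pos_ids_for_vision

# One video: t=2 frames of 4x4 patches -> a 2x2 grid per frame after the
# 2x2 spatial merge, i.e. 8 LLM positions in total.
grid_thw = torch.tensor([[2, 4, 4]])
spatial_merge_size = 2
position_id_per_seconds = 25  # illustrative value only

t_index = torch.arange(grid_thw[0][0]) * position_id_per_seconds
pos = get_llm_pos_ids_for_vision(
    0,                   # start_idx: first position id to assign
    0,                   # vision_idx: which row of grid_hs/grid_ws to use
    spatial_merge_size,
    t_index,
    grid_thw[:, 1],      # grid_hs
    grid_thw[:, 2],      # grid_ws
)
# pos.shape == (3, 8): one column per merged patch. Row 0 is the temporal
# coordinate (0 for the first frame, 25 for the second); rows 1 and 2
# enumerate the 2x2 spatial grid -- the (t, h, w) layout M-RoPE consumes.
print(pos)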