diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index d3ee6872dd8bf..4ccba64f2c110 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, SupportsTranscription, SupportsV0Only,
-                         has_inner_state, supports_lora, supports_multimodal,
-                         supports_pp, supports_transcription, supports_v0_only)
+from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
+                         SupportsMultiModal, SupportsPP, SupportsTranscription,
+                         SupportsV0Only, has_inner_state, supports_lora,
+                         supports_mrope, supports_multimodal, supports_pp,
+                         supports_transcription, supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
@@ -21,6 +22,8 @@ __all__ = [
     "supports_lora",
     "SupportsMultiModal",
     "supports_multimodal",
+    "SupportsMRoPE",
+    "supports_mrope",
     "SupportsPP",
     "supports_pp",
     "SupportsTranscription",
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 8f8e300c84d71..e9c600e36cfa7 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
 import numpy as np
 import torch
 from torch import Tensor
+from transformers import PretrainedConfig
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs
 
@@ -852,3 +853,70 @@ def supports_eagle3(
     model: Union[type[object], object],
 ) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
     return isinstance(model, SupportsEagle3)
+
+
+@runtime_checkable
+class SupportsMRoPE(Protocol):
+    """The interface required for all models that support M-RoPE."""
+
+    supports_mrope: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports M-RoPE.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """
+        Get M-RoPE input positions and delta value for this specific model.
+
+        This method should be implemented by each model that supports M-RoPE
+        to provide model-specific logic for computing input positions.
+
+        Args:
+            input_tokens: List of input token IDs
+            hf_config: HuggingFace model configuration
+            image_grid_thw: Image grid dimensions (t, h, w)
+            video_grid_thw: Video grid dimensions (t, h, w)
+            second_per_grid_ts: Seconds per grid timestep for videos
+            context_len: Context length
+            seq_len: Sequence length
+            audio_feature_lengths: Audio feature lengths for multimodal models
+            use_audio_in_video: Whether to use audio in video for interleaving
+
+        Returns:
+            Tuple of (llm_positions, mrope_position_delta)
+            - llm_positions: Tensor of shape [3, num_tokens]
+                with T/H/W positions
+            - mrope_position_delta: Delta for position calculations
+        """
+        ...
+
+
+@overload
+def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
+    ...
+
+
+@overload
+def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
+    ...
+
+
+def supports_mrope(
+    model: Union[type[object], object],
+) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
+    return isinstance(model, SupportsMRoPE)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b6576b783b64a..7f361678ba72e 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -32,7 +32,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import AutoConfig, BatchFeature
+from transformers import AutoConfig, BatchFeature, PretrainedConfig
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
 from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -73,7 +73,7 @@ from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
@@ -1096,7 +1096,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
                                         info=Qwen2VLProcessingInfo,
                                         dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsLoRA, SupportsPP):
+                                      SupportsLoRA, SupportsPP, SupportsMRoPE):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.": "language_model.model.",
         })
 
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get M-RoPE input positions for Qwen2-VL model."""
+        if image_grid_thw is None:
+            image_grid_thw = []
+        if video_grid_thw is None:
+            video_grid_thw = []
+        if second_per_grid_ts is None:
+            second_per_grid_ts = []
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+                       tokens_per_second).long().flatten()
+
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1,
+                                            -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4873b586724ec..053e8f0537ed9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -42,6 +42,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
                                                    supports_eagle3,
+                                                   supports_mrope,
                                                    supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -730,16 +731,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 if mm_input.get("use_audio_in_video") is True:
                     use_audio_in_video = True
 
-        req_state.mrope_positions, req_state.mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions_tensor(
-                req_state.prompt_token_ids,
-                hf_config=self.model_config.hf_config,
-                image_grid_thw=image_grid_thw,
-                video_grid_thw=video_grid_thw,
-                second_per_grid_ts=second_per_grid_ts,
-                audio_feature_lengths=audio_feature_lengths,
-                use_audio_in_video=use_audio_in_video,
-            )
+        if supports_mrope(self.model):
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                self.model.get_mrope_input_positions(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
+        else:
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions_tensor(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
 
     def _extract_mm_kwargs(
         self,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 88f83c9dd7e6c..594382650e3c1 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                 get_sampler)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models import (supports_lora, supports_mrope,
+                                        supports_multimodal)
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
@@ -670,18 +671,33 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 inter_data.seq_ids[seq_idx]]
             token_ids = seq_data.get_token_ids()
 
-            mrope_input_positions, mrope_position_delta = \
-                MRotaryEmbedding.get_input_positions(
-                    token_ids,
-                    hf_config=hf_config,
-                    image_grid_thw=image_grid_thw,
-                    video_grid_thw=video_grid_thw,
-                    second_per_grid_ts=second_per_grid_ts,
-                    context_len=inter_data.context_lens[seq_idx],
-                    seq_len=inter_data.seq_lens[seq_idx],
-                    audio_feature_lengths=audio_feature_lengths,
-                    use_audio_in_video=use_audio_in_video,
-                )
+            if supports_mrope(self.runner.model):
+                mrope_input_positions, mrope_position_delta = \
+                    self.runner.model.get_mrope_input_positions(
+                        token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        context_len=inter_data.context_lens[seq_idx],
+                        seq_len=inter_data.seq_lens[seq_idx],
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
+                mrope_input_positions = mrope_input_positions.tolist()
+            else:
+                mrope_input_positions, mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions(
+                        token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        context_len=inter_data.context_lens[seq_idx],
+                        seq_len=inter_data.seq_lens[seq_idx],
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
 
             seq_data.mrope_position_delta = mrope_position_delta
             inter_data.mrope_input_positions[
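
A minimal usage sketch (not part of the patch) of how the new interface is detected, assuming the diff above is applied. `ToyMRoPEModel` is a hypothetical class used only for illustration; because `SupportsMRoPE` is a `runtime_checkable` Protocol, any model that defines the flag and the method passes the `supports_mrope` check without subclassing, which is what the runner dispatch above relies on.

# Illustrative sketch only; ToyMRoPEModel is hypothetical and the vLLM import
# assumes this patch has been applied.
from typing import ClassVar, Literal, Optional, Union

import torch
from transformers import PretrainedConfig

from vllm.model_executor.models.interfaces import supports_mrope


class ToyMRoPEModel:
    # Opting in only requires the flag and the method; no explicit
    # inheritance from SupportsMRoPE is needed.
    supports_mrope: ClassVar[Literal[True]] = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        # Text-only toy: positions are identical along the T/H/W axes.
        positions = torch.arange(len(input_tokens)).view(1, -1).expand(3, -1)
        return positions[:, context_len:seq_len], 0


assert supports_mrope(ToyMRoPEModel)    # class-level check
assert supports_mrope(ToyMRoPEModel())  # instance-level check
# The GPU model runners use the same check to choose between the model's own
# get_mrope_input_positions() and the generic MRotaryEmbedding fallback.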