From 0b8166aa8f02206e0c4ad0a295b86335324c6ecf Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 27 Sep 2025 16:15:12 +0800 Subject: [PATCH] [Bugfix] Merge MM embeddings by index instead of token IDs (#16229) Signed-off-by: DarkLight1337 Signed-off-by: NickLucche Signed-off-by: Roger Wang Co-authored-by: NickLucche Co-authored-by: Roger Wang Signed-off-by: yewentao256 --- docs/contributing/model/multimodal.md | 33 +----- vllm/config/model.py | 7 +- vllm/model_executor/models/aria.py | 25 ++--- vllm/model_executor/models/aya_vision.py | 27 ++--- vllm/model_executor/models/bert.py | 12 +++ vllm/model_executor/models/bert_with_rope.py | 6 ++ vllm/model_executor/models/blip2.py | 22 ++-- vllm/model_executor/models/chameleon.py | 24 ++--- vllm/model_executor/models/chatglm.py | 3 + vllm/model_executor/models/cohere2_vision.py | 27 ++--- vllm/model_executor/models/deepseek_eagle.py | 6 ++ vllm/model_executor/models/deepseek_mtp.py | 6 ++ vllm/model_executor/models/deepseek_vl2.py | 25 ++--- vllm/model_executor/models/dots_ocr.py | 44 +++----- vllm/model_executor/models/ernie45_vl.py | 27 +++-- vllm/model_executor/models/ernie_mtp.py | 6 ++ vllm/model_executor/models/fuyu.py | 26 ++--- vllm/model_executor/models/gemma3_mm.py | 26 ++--- vllm/model_executor/models/gemma3n_mm.py | 23 ++-- vllm/model_executor/models/glm4_1v.py | 17 --- vllm/model_executor/models/glm4_moe_mtp.py | 6 ++ vllm/model_executor/models/glm4v.py | 35 ++---- vllm/model_executor/models/granite_speech.py | 34 +++--- vllm/model_executor/models/hunyuan_v1.py | 3 + .../models/hyperclovax_vision.py | 35 ++---- vllm/model_executor/models/idefics3.py | 31 ++---- vllm/model_executor/models/interfaces.py | 83 ++++++++++++-- vllm/model_executor/models/interfaces_base.py | 24 ++++- vllm/model_executor/models/interns1.py | 47 ++++---- vllm/model_executor/models/internvl.py | 46 ++++---- vllm/model_executor/models/keye.py | 18 ---- vllm/model_executor/models/kimi_vl.py | 29 +---- vllm/model_executor/models/lfm2.py | 3 + vllm/model_executor/models/llama4_eagle.py | 31 ++---- vllm/model_executor/models/llama_eagle.py | 6 ++ vllm/model_executor/models/llama_eagle3.py | 17 +-- vllm/model_executor/models/llava.py | 26 ++--- vllm/model_executor/models/llava_next.py | 31 +++--- .../model_executor/models/llava_next_video.py | 23 ++-- vllm/model_executor/models/llava_onevision.py | 13 --- vllm/model_executor/models/midashenglm.py | 25 ++--- vllm/model_executor/models/mimo_mtp.py | 6 ++ vllm/model_executor/models/minicpmv.py | 27 ++--- vllm/model_executor/models/minimax_text_01.py | 10 +- vllm/model_executor/models/minimax_vl_01.py | 25 ++--- vllm/model_executor/models/mistral3.py | 26 ++--- vllm/model_executor/models/mllama4.py | 30 ++---- vllm/model_executor/models/modernbert.py | 9 ++ vllm/model_executor/models/molmo.py | 32 ++---- .../model_executor/models/nano_nemotron_vl.py | 43 +++----- vllm/model_executor/models/nemotron_vl.py | 37 ++++--- vllm/model_executor/models/olmo2.py | 6 ++ vllm/model_executor/models/ovis.py | 21 +--- vllm/model_executor/models/ovis2_5.py | 18 +--- vllm/model_executor/models/paligemma.py | 23 ++-- vllm/model_executor/models/phi3v.py | 44 +++++--- vllm/model_executor/models/phi4_multimodal.py | 18 +--- vllm/model_executor/models/phi4mm.py | 14 --- vllm/model_executor/models/pixtral.py | 26 ++--- .../models/qwen2_5_omni_thinker.py | 26 ++--- vllm/model_executor/models/qwen2_5_vl.py | 13 --- vllm/model_executor/models/qwen2_audio.py | 23 ++-- vllm/model_executor/models/qwen2_vl.py | 13 --- 
 vllm/model_executor/models/qwen3_vl.py        |  89 +++++++++------
 vllm/model_executor/models/qwen_vl.py         |  25 ++---
 vllm/model_executor/models/roberta.py         |   3 +
 vllm/model_executor/models/skyworkr1v.py      |  36 ++++---
 vllm/model_executor/models/solar.py           |   3 +
 vllm/model_executor/models/step3_text.py      |   3 +
 vllm/model_executor/models/step3_vl.py        |  52 ++++-----
 vllm/model_executor/models/tarsier.py         |  25 ++---
 vllm/model_executor/models/terratorch.py      |   3 +
 vllm/model_executor/models/transformers.py    |  62 ++++++++---
 vllm/model_executor/models/ultravox.py        |  36 ++++---
 vllm/model_executor/models/utils.py           | 102 +++++++-----------
 vllm/model_executor/models/voxtral.py         |  29 ++---
 vllm/model_executor/models/whisper.py         |  10 +-
 vllm/v1/spec_decode/eagle.py                  |  46 ++++----
 vllm/v1/worker/gpu_model_runner.py            |  52 ++++++---
 vllm/v1/worker/tpu_model_runner.py            |  81 +++++++++++---
 80 files changed, 966 insertions(+), 1139 deletions(-)

diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md
index 87d34d207cde3..1d72fe97b9665 100644
--- a/docs/contributing/model/multimodal.md
+++ b/docs/contributing/model/multimodal.md
@@ -66,35 +66,12 @@ Further update the model as follows:
 !!! important
     The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
 
-- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
+!!! note
+    By default, vLLM merges the multimodal embeddings into the text embeddings based on their locations, as defined by
+    [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] during input processing.
+    This logic can be found in [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
 
-  ??? code
-
-      ```python
-      from .utils import merge_multimodal_embeddings
-
-      class YourModelForImage2Seq(nn.Module):
-          ...
-
-          def get_input_embeddings(
-              self,
-              input_ids: torch.Tensor,
-              multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-          ) -> torch.Tensor:
-
-              # `get_input_embeddings` should already be implemented for the language
-              # model as one of the requirements of basic vLLM model implementation.
-              inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-
-              if multimodal_embeddings is not None:
-                  inputs_embeds = merge_multimodal_embeddings(
-                      input_ids=input_ids,
-                      inputs_embeds=inputs_embeds,
-                      multimodal_embeddings=multimodal_embeddings,
-                      placeholder_token_id=self.config.image_token_index)
-
-              return inputs_embeds
-      ```
+    You may override this method if your model requires additional logic when merging the embeddings.
 
 - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
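To make the new default concrete, the sketch below illustrates the index-based merge this patch standardizes on: multimodal embeddings are scattered into the positions selected by a boolean `is_multimodal` mask (derived from `PlaceholderRange` information) rather than looked up by placeholder token ID. The helper name `merge_by_index` and the toy tensors are illustrative assumptions only; the actual implementation lives in `_merge_multimodal_embeddings` in `vllm/model_executor/models/utils.py` and is exercised via `SupportsMultiModal.get_input_embeddings` later in this patch.

```python
# Illustrative sketch only (not part of the patch): merge multimodal embeddings
# into text embeddings by position mask instead of by placeholder token ID.
import torch


def merge_by_index(
    inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size) text embeddings
    multimodal_embeddings: list[torch.Tensor],  # each (feature_size, hidden_size)
    is_multimodal: torch.Tensor,  # (num_tokens,) bool mask of placeholder slots
) -> torch.Tensor:
    # Flatten all multimodal items and scatter them into the placeholder slots.
    flattened = torch.cat(multimodal_embeddings)
    assert flattened.shape[0] == int(is_multimodal.sum()), \
        "number of multimodal embeddings must match the number of placeholder slots"
    inputs_embeds[is_multimodal] = flattened.to(dtype=inputs_embeds.dtype)
    return inputs_embeds


# Toy example: 6 tokens, one image item occupying 2 placeholder slots.
hidden_size = 4
inputs_embeds = torch.zeros(6, hidden_size)
image_embeds = [torch.ones(2, hidden_size)]
is_multimodal = torch.tensor([False, True, True, False, False, False])
merged = merge_by_index(inputs_embeds, image_embeds, is_multimodal)
```

Model code then only has to supply the mask, e.g. `is_multimodal=input_ids == self.config.image_token_index` for a single placeholder token (as the per-model hunks below do), while the model runner can derive the mask directly from the `PlaceholderRange` metadata.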
diff --git a/vllm/config/model.py b/vllm/config/model.py index b2b68abd2c1d3..3fb448ebbf364 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -509,9 +509,14 @@ class ModelConfig: else: # task == "auto" pass else: + debug_info = { + "architectures": architectures, + "is_generative_model": is_generative_model, + "is_pooling_model": is_pooling_model, + } raise AssertionError("The model should be a generative or " "pooling model when task is set to " - f"{self.task!r}.") + f"{self.task!r}. Found: {debug_info}") self.runner = runner self.convert = convert diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 35c1adbdd00b6..6cef5e134a4bc 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -38,8 +38,7 @@ from .idefics2_vision_model import ( from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, maybe_prefix, - merge_multimodal_embeddings) + is_pp_missing_parameter, maybe_prefix) class AriaImagePixelInputs(TensorSchema): @@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): multimodal_embeddings = self._process_image_input(image_input) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ) -> Union[torch.Tensor, IntermediateTensors]: if inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 6fd8c2fb5c561..eab996e9ba22b 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class AyaVisionImagePixelInputs(TensorSchema): @@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and 
len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index ee32587f6b1b4..c984845204c4f 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant): self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): prefix=maybe_prefix(prefix, "model")) self.pooler = self._build_pooler(pooler_config) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) @@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module): Pooler.for_encode(pooler_config), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loaded_params = loader.load_weights(weights) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index bfc1408ddf880..4e1eba32d2594 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant): prefix=f"{prefix}.encoder") self.pooler = BertPooler(self.config) if add_pooling_layer else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): loaded_params = loader.load_weights(weights) return loaded_params + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.new.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b7455fba62c02..4d1850d07b28e 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -27,7 +27,7 @@ from .blip import 
BlipVisionModel from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - _IMAGE_TOKEN_ID) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 79d648d749c6a..f9740adb151b5 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) logger = init_logger(__name__) @@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.model.vocabulary_mapping.image_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + image_token_id = self.model.vocabulary_mapping.image_token_id + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == image_token_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 879508400222f..c182201fe2567 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 6d67eb68d51a8..99edcba4d874a 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) class Cohere2VisionImagePixelInputs(TensorSchema): @@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model.model( diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index ed7e7614800fc..c42a66d86912c 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module): self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 92f311ab465b5..a4623ff13cec4 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c8ed759d2e972..b98008c83bdcc 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) # The image token id may be various _IMAGE_TOKEN = "" @@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): model_config = vllm_config.model_config tokenizer = cached_tokenizer_from_config(model_config) - self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] self.vision = self._init_vision_module(self.vision_config, quant_config, @@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = 
merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2db350c892ae7..4845f19bcbc42 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, Qwen2VLProcessingInfo) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + maybe_prefix) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict @@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: inputs_embeds = None - elif inputs_embeds is None and kwargs.get("pixel_values") is not None: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - inputs_embeds = None - else: - assert input_ids is not None - inputs_embeds = self.get_multimodal_embeddings( - input_ids, - image_input=image_input, - ) - input_ids = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) + input_ids = None hidden_states = self.language_model( input_ids=input_ids, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 74b358034ef3d..a73ec4f88ffe4 100644 --- 
a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: + self._set_visual_token_mask(input_ids) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None: - return inputs_embeds - - self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, - multimodal_embeddings, - [self.config.im_patch_id]) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 288fbe736c32f..3b24bf2f1ef8f 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 53e9e6fe6e460..b99fe33a1dcce 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - _IMAGE_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == _IMAGE_TOKEN_ID, + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 0630ee07c347e..be75e36fe23b5 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) if (vision_embeddings is not None) and len(vision_embeddings) != 0: kwargs = self.prepare_attn_masks( input_ids, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 2acdba54a257d..b23437a08e5ab 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) # NOTE (NickLucche) Each pass needs tokens to 
compute PLE so we cache # them here, as the model forward has only access to the input_embeds. if input_ids is not None: @@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( per_layer_inputs) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - # NOTE: this order of processing mm items is important - [self.config.image_token_id, self.config.audio_token_id]) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b088e0c0dd241..dbb5431ae4919 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0 - and all(embed.numel() > 0 for embed in multimodal_embeddings)): - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index c572978e62206..826d541e571bd 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP): prefix=maybe_prefix( prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index bf33575859aea..ace9c05daf15a 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .chatglm import ChatGLMBaseModel, ChatGLMModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn, isin_list class GLMVImagePixelInputs(TensorSchema): @@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - 
def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=[ - self.config.boi_token_id, - self.config.pad_token_id, - self.config.eoi_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, [ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ]), + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index a5849184339b1..8a02da58ea0b9 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip2 import Blip2QFormerModel from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, embed_multimodal, - init_vllm_registered_model, maybe_prefix) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix ### Audio Input @@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration( # Split variable length features into a tuple return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + def get_language_model(self) -> torch.nn.Module: + return self.language_model + def get_multimodal_embeddings( self, **kwargs: object, @@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration( audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] - return None + audio_features = self._process_audio_input(audio_input) return audio_features @@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - """Compute the merged LLM / audio embeddings.""" - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.audio_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration( # condition is for v0 
compatibility. elif inputs_embeds is None: audio_embeds = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeds, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None model_output = self.language_model(input_ids, positions, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 8a23a6b45bc70..d28c971167902 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 4d39ff9ae79ee..f851688bf7bab 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list, + maybe_prefix) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def get_multimodal_embeddings( self, **kwargs: Unpack[HCXVisionMultimodalInputs], - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings: multimodal_embeddings = list() if kwargs.get("pixel_values_images") is not None: @@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - placeholder_token_id=[ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=isin_list( + input_ids, + [self.config.image_token_id, self.config.video_token_id]), + ) input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 79e130119ae83..3334ee2242531 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -52,8 +52,7 @@ from .idefics2_vision_model import ( # yapf: enable from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .llama import LlamaModel -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix class Idefics3ImagePixelInputs(TensorSchema): @@ -539,10 +538,7 @@ class Idefics3Model(nn.Module): return image_hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.text_model.get_input_embeddings(input_ids) def forward( @@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.model.text_model(input_ids, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index f13e590cd243b..d40df9b43dd43 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, MutableSequence -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - Union, overload, runtime_checkable) +from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional, + Protocol, Union, overload, runtime_checkable) import numpy as np import torch @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import supports_kw -from .interfaces_base import is_pooling_model +from .interfaces_base import VllmModel, is_pooling_model if TYPE_CHECKING: from vllm.config import VllmConfig @@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol): """ ... 
-    def get_language_model(self) -> torch.nn.Module:
+    def get_language_model(self) -> VllmModel:
         """
         Returns the underlying language model used for text generation.
@@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
         """
         ...
 
+    @overload
+    def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def get_input_embeddings(
+        self,
+        input_ids: Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        *,
+        is_multimodal: torch.Tensor,
+        handle_oov_mm_token: bool = False,
+    ) -> Tensor:
+        ...
+
+    def _get_text_embeddings(
+        self,
+        input_ids: Tensor,
+        get_input_embeddings: Callable[[Tensor], Tensor],
+        *,
+        is_multimodal: Optional[Tensor],
+        handle_oov_mm_token: bool,
+    ) -> Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+
+        return get_input_embeddings(input_ids)
+
     def get_input_embeddings(
         self,
         input_ids: Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> Tensor:
         """
-        Returns the input embeddings merged from the text embeddings from
-        input_ids and the multimodal embeddings generated from multimodal
-        kwargs.
+        Apply token embeddings to `input_ids`.
+
+        If `multimodal_embeddings` is passed, scatter them into the
+        resulting embeddings according to the mask `is_multimodal`.
+
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=True`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
         """
-        ...
+        from .utils import _merge_multimodal_embeddings
+
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
 
 
 @runtime_checkable
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 8fdf70e35a2b8..84146db0943c6 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
     ) -> None:
         ...
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply token embeddings to `input_ids`."""
+        ...
+ def forward( self, input_ids: torch.Tensor, @@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: return supports_kw(model_init, "vllm_config") +def _check_vllm_model_get_input_embeddings( + model: Union[type[object], object]) -> bool: + model_get_input_embeddings = getattr(model, "get_input_embeddings", None) + if not callable(model_get_input_embeddings): + logger.warning( + "The model (%s) is missing the `get_input_embeddings` method.", + model, + ) + return False + + return True + + def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): @@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: def is_vllm_model( model: Union[type[object], object], ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: - return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + return (_check_vllm_model_init(model) + and _check_vllm_model_get_input_embeddings(model) + and _check_vllm_model_forward(model)) @runtime_checkable diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 197d629b906fe..545dad1a96f5e 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, isin_list, maybe_prefix) class InternS1MultiModalProjector(nn.Module): @@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f4004e518e3ba..78aac85414344 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + isin_list, maybe_prefix) IMG_START = '' IMG_END = '' @@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 3b6fdba225122..62a71b7b1fa85 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module): multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 503627865c4a5..db032736f9148 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model from vllm.model_executor.models.interfaces import (SupportsMultiModal, SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel -from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.media_placeholder_token_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, if image_input is None: inputs_embeds = None else: - inputs_embeds = self.get_input_embeddings(input_ids) image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( + inputs_embeds = self.get_input_embeddings( input_ids, - inputs_embeds, image_embeds, - placeholder_token_id=self.config. 
- media_placeholder_token_id, + is_multimodal=input_ids == + self.config.media_placeholder_token_id, ) input_ids = None diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 53c36e4e52d81..f9def222a1ec9 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index a203af53205cd..235275c0940a1 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, Llama4ForCausalLM) from vllm.model_executor.models.utils import extract_layer_index -from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, maybe_prefix logger = init_logger(__name__) @@ -79,10 +79,7 @@ class LlamaModel(nn.Module): self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_language_model(self) -> torch.nn.Module: + return self.model + + get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore + def forward( self, input_ids: torch.Tensor, @@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): skip_prefixes=(["lm_head."]), ) loader.load_weights(map(transform, weights)) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 2ff2d54a83aa8..d6e6fd3fcfe9c 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -73,6 +73,9 @@ class LlamaModel(nn.Module): self.config.hidden_size, bias=False) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama_eagle3.py 
b/vllm/model_executor/models/llama_eagle3.py index 55b6ae6ee0e9c..34b8ea0ca5360 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -144,10 +143,7 @@ class LlamaModel(nn.Module): eps=self.config.rms_norm_eps, ) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): requires_grad=False, ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - return inputs_embeds diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4d8ed95b6cc8f..6f3cfd88aee23 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c9133fde14552..e132389c4f061 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, - flatten_bn, init_vllm_registered_model, maybe_prefix) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix) class LlavaNextImagePixelInputs(TensorSchema): @@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) - - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens]. # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 610fb188d57d2..2642d8c77cf3b 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self._process_video_pixels(video_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.video_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.video_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index cee9ddaf94cc4..906858f4e2f47 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_index, self.config.video_token_index]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 82648ba668ca5..0bf04e0e7e2fa 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -54,8 +54,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.midashenglm import DashengConfig from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix _Tuple2 = Union[int, tuple[int, int], Sequence[int]] @@ -744,21 +743,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): return [] return self._process_audio_input(audio_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.decoder.get_input_embeddings(input_ids) - if multimodal_embeddings and len(multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.audio_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -771,8 +755,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): inputs_embeds = None elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_id, + ) input_ids = None return self.decoder.model(input_ids, diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index b4abe458e4771..9c1e36094c4a3 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module): self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return 
self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -158,6 +161,9 @@ class MiMoMTP(nn.Module): self.config.hidden_size, prefix=maybe_prefix(prefix, "lm_head")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a17c4f004d75c..bffc9a0c125ea 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -71,8 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -1144,23 +1143,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): return self._process_multimodal_inputs(modalities) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert len(self.mm_token_ids) > 0 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - list(self.mm_token_ids), - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1178,8 +1160,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, list(self.mm_token_ids)), + ) input_ids = None hidden_states = self.llm.model( diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index cc9a959f63313..a92890c9f7b55 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -592,10 +592,7 @@ class MiniMaxText01Model(nn.Module): dtype=torch.long) minimax_cache_tensors[:, slots_tensor, ...] 
= 0 - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward(self, @@ -687,10 +684,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( batch_size) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward(self, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index d81ac8c704e79..d41b9d3f14fe2 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -28,7 +28,7 @@ from .llava_next import LlavaNextProcessingInfo from .pixtral import PixtralHFVisionModel from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) class MiniMaxVL01ImagePixelInputs(TensorSchema): @@ -218,22 +218,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -403,8 +387,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index ba6da4403ae16..31571ce962d18 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -38,8 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info @@ -524,22 +523,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - 
self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -592,8 +575,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 79e315f794893..3af5267928cde 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -56,8 +56,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -813,24 +812,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -846,8 +827,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # this condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None return self.language_model(input_ids, positions, intermediate_tensors, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 1d5da3139de92..e4a51b3697370 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -43,6 +43,9 @@ class ModernBertEmbeddings(nn.Module): eps=config.layer_norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,6 +223,9 @@ class ModernBertModel(nn.Module): eps=config.norm_eps, bias=config.norm_bias) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) @@ -333,6 +339,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ), }) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 201bf83cac581..054caee9e8a4f 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -58,7 +58,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] @@ -819,10 +819,7 @@ class MolmoModel(nn.Module, SupportsQuant): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1481,24 +1478,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_patch_id is not None - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_patch_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.LongTensor, @@ -1515,8 +1494,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_patch_id, + ) input_ids = None hidden_states = self.model(input_ids, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2b68d40cf2c67..505806a15c891 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -35,8 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.utils import (flatten_bn, init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) + isin_list, maybe_prefix) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, @@ -1096,8 +1095,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, return modalities - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: # Validate the multimodal input keyword arguments modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if modalities is None: @@ -1121,30 +1120,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - - return inputs_embeds - def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -1163,9 +1138,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=isin_list(input_ids, context_token_ids), + ) input_ids = None hidden_states = self.language_model( diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 3abbff8c717d4..2627a262e9582 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -38,7 +38,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -576,20 +576,24 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [self.img_context_token_id] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -608,8 +612,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 2e0b1fb2a13f7..e7e30ee8df0ff 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -295,6 +295,9 @@ class Olmo2Model(nn.Module): make_empty_intermediate_tensors_factory(["hidden_states"], self.config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,6 +411,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index bd525b6780e0d..8503d3f71d1c9 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -48,7 +48,6 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. IMAGE_TOKEN = "" @@ -501,19 +500,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): return image_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_pad_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -529,8 +515,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have an inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f18e38ce154d2..2ecc7bff07e07 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -585,17 +585,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - tmp = torch.concat(multimodal_embeddings, dim=0) - inputs_embeds[input_ids == self.image_pad_token_id] = tmp - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -612,8 +601,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.image_pad_token_id, + ) input_ids = None # up until here we have a inputs_embeds 100% numerical identity diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index aef5102304614..f07f444819f4c 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -26,8 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import get_vision_encoder_info logger = init_logger(__name__) @@ -362,19 +361,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -388,8 +374,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == self.config.image_token_index,
+            )
             input_ids = None

         hidden_states = self.language_model.model(input_ids,
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index a2b201fe4228d..ea34c8d92f136 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -51,9 +51,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, WeightsMapper,
+                    _merge_multimodal_embeddings, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix)

 logger = init_logger(__name__)

@@ -643,14 +643,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
-        inputs_embeds = self.embed_tokens(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.image_token_id)
-        return inputs_embeds
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.embed_tokens,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )

     def forward(self,
                 input_ids: torch.Tensor,
@@ -666,8 +683,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         # condition is for v0 compatibility
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == self.image_token_id,
+            )
             input_ids = None

         hidden_states = self.language_model.model(input_ids,
diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py
index bdc831354c11f..e8b79717d75d0 100644
--- a/vllm/model_executor/models/phi4_multimodal.py
+++ b/vllm/model_executor/models/phi4_multimodal.py
@@ -1342,12 +1342,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
                                        image_attention_mask)
         return image_embeds

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

         if not modalities:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of
tensors, with each # tensor corresponding to a multimodal data item (image or video). @@ -1371,18 +1371,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 47b5ad55ab2d0..15b09c7ae2bc9 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1151,7 +1151,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each # tensor corresponding to a multimodal data item (image or video). @@ -1175,19 +1174,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.embed_tokens(input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 7b197844c8b63..2c04b6f0f4f90 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -50,8 +50,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs try: @@ -433,22 +432,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.vision_args.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -465,8 +448,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. 
elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.vision_args.image_token_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 5f27230c913b4..bfa398ee43b56 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -865,24 +865,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration( multimodal_embeddings += audio_embeddings return multimodal_embeddings + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - # TODO (ywang96): support overlapping modality embeddings so that - # `use_audio_in_video` will work on V1. - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - self.config.image_token_index, - self.config.video_token_index, - self.config.audio_token_index - ]) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def get_multimodal_embeddings_v0( self, **kwargs: object) -> Optional[NestedTensors]: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index adb21373056c5..5b092b42205fa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 762ab42e5929e..9dfa29eef5ce7 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from 
.utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix # # === Audio Inputs === # @@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, masked_audio_features = self._process_audio_input(audio_input) return masked_audio_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, # condition is for v0 compatibility. elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None hidden_states = self.language_model.model(input_ids, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d4e195246bf13..8192c3ce05dd2 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - def get_input_embeddings_v0( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f1aeb99a4d373..5d0b66f91aced 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention, from .qwen2_vl import Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - maybe_prefix, merge_multimodal_embeddings) + _merge_multimodal_embeddings, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings def _compute_deepstack_embeds( - self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: - visual_lens = [ - x.shape[0] if isinstance(x, torch.Tensor) else len(x) - for x in multimodal_embeddings - ] + self, + inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: torch.Tensor, + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: + visual_lens = [len(x) for x in multimodal_embeddings] multimodal_embeddings_cat = torch.cat(multimodal_embeddings, 
                                               dim=0)
-        multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split(  # noqa:E501
-            multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
-            dim=-1)
+        (
+            multimodal_embeddings_main,
+            multimodal_embeddings_multiscale,
+        ) = torch.split(
+            multimodal_embeddings_cat,
+            [self.visual_dim, self.multiscale_dim],
+            dim=-1,
+        )
         multimodal_embeddings = torch.split(multimodal_embeddings_main,
                                             visual_lens,
@@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             inputs_embeds.size(0),
             self.deepstack_num_level * inputs_embeds.size(1))
-        deepstack_input_embeds = merge_multimodal_embeddings(
-            input_ids,
-            deepstack_input_embeds,
-            multimodal_embeddings_multiscale,
-            placeholder_token_id=[
-                self.config.image_token_id, self.config.video_token_id
-            ],
+        deepstack_input_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=deepstack_input_embeds,
+            multimodal_embeddings=multimodal_embeddings_multiscale,
+            is_multimodal=is_multimodal,
         )
         deepstack_input_embeds = deepstack_input_embeds.view(
             inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
         deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
+        return deepstack_input_embeds, multimodal_embeddings

     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
-        deepstack_input_embeds = None
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None:
-            if self.use_deepstack:
-                deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
-                    input_ids, inputs_embeds, multimodal_embeddings)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                [self.config.image_token_id, self.config.video_token_id])
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.language_model.get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")

         if self.use_deepstack:
-            if deepstack_input_embeds is None:
-                deepstack_input_embeds = torch.zeros_like(
-                    inputs_embeds).unsqueeze(0).repeat(
-                        self.deepstack_num_level, 1, 1).contiguous()
+            (
+                deepstack_input_embeds,
+                multimodal_embeddings,
+            ) = self._compute_deepstack_embeds(
+                inputs_embeds=inputs_embeds,
+                multimodal_embeddings=multimodal_embeddings,
+                is_multimodal=is_multimodal,
+            )
+        else:
+            deepstack_input_embeds = None
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
+
+        if deepstack_input_embeds is not None:
             self._set_deepstack_input_embeds(deepstack_input_embeds)

         return inputs_embeds
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 90200f319464b..dc11b60604a91 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import
TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .qwen import QWenBaseModel, QWenModel -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn class QwenImagePixelInputs(TensorSchema): @@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.transformer.visual.image_pad_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == + self.transformer.visual.image_pad_id, + ) input_ids = None hidden_states = self.transformer(input_ids, positions, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index ba405be416876..53e698c4fa806 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.roberta.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 893ce4497c319..f9a107c06085b 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) IMG_START = '' IMG_END = '' @@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_context_token_id is not None + if multimodal_embeddings is not None and len( + multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_context_token_id, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is 
None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.img_context_token_id, + ) input_ids = None forward_kwargs = { diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index c774171b9dcd2..c5b82b0ca4a07 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 0cce0c78f8dc6..0fe723d594839 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5f6ad58850439..ad295ef447325 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) from .vision import run_dp_sharded_vision_model @@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, 1 else cur_feature[0]) return merged_image_features - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: - return None + return [] vision_embeddings = self._process_image_input(image_input) return 
vision_embeddings @@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - if multimodal_embeddings is None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - else: - is_text = input_ids != self.config.image_token_id - text_ids = input_ids[is_text] - text_embeds = self.language_model.model.get_input_embeddings( - text_ids) - inputs_embeds = torch.empty(input_ids.shape[0], - text_embeds.shape[-1], - dtype=text_embeds.dtype, - device=text_embeds.device) - inputs_embeds[is_text] = text_embeds - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_id) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None hidden_states = self.language_model(input_ids, diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 3660efdc079aa..1145bea414808 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -40,7 +40,7 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) + maybe_prefix) from .vision import VisionEncoderInfo, get_vision_encoder_info @@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == 
self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index b9dfa8e9b6f51..938b02e3e04b3 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # We do not really use any input tokens and therefore no embeddings # to be calculated. However, due to the mandatory token ids in diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 19dd242f16eb6..3d7b06633f342 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, - SupportsQuant) +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP, SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, flatten_bn, make_empty_intermediate_tensors_factory, maybe_prefix) @@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings()(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, @@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): multimodal_embeds = self.get_multimodal_embeddings(**kwargs) if multimodal_embeds is not None: inputs_embeds = self.get_input_embeddings( - input_ids, multimodal_embeds) + input_ids, + multimodal_embeds, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None model_output = super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output + def get_language_model(self) -> torch.nn.Module: + return self.model + def get_multimodal_embeddings(self, **kwargs): pixel_values = kwargs.pop("pixel_values", None) pixel_values = pixel_values if pixel_values is not None else kwargs.pop( @@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings=None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - mask = (input_ids == self.config.image_token_id) - mask = mask.unsqueeze(-1).expand_as(inputs_embeds) - multimodal_embeddings = torch.cat(multimodal_embeddings) + """ + Apply token embeddings to `input_ids`. - inputs_embeds = inputs_embeds.masked_scatter( - mask, multimodal_embeddings) - return inputs_embeds + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. 
+ + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=True` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. + """ + from .utils import _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.model.get_input_embeddings(), + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 12ae9487ad9dc..77e886c22e634 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) + init_vllm_registered_model, maybe_prefix) _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 @@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - # The audio token index is not included in the embedding table - # We need to remove it before embedding lookup - safe_input_ids = input_ids.clone() - safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0 - inputs_embeds = self.language_model.get_input_embeddings( - safe_input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward(self, input_ids: torch.Tensor, @@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.audio_token_index, + ) input_ids = None language_model = self.language_model diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 51cd41c864f09..7b3f20c6b28a1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import
itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn @@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, - is_multimodal: torch.Tensor, multimodal_embeddings: NestedTensors, + is_multimodal: torch.Tensor, ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -402,63 +402,37 @@ def _merge_multimodal_embeddings( Note: This updates ``inputs_embeds`` in place. """ - flattened = _flatten_embeddings(multimodal_embeddings) - try: - # This is equivalent to: inputs_embeds[is_multimodal] = flattened. - inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), - flattened.to(dtype=inputs_embeds.dtype)) - except RuntimeError as e: - num_expected_tokens = is_multimodal.sum().item() - assert isinstance(num_expected_tokens, int) + if len(multimodal_embeddings) == 0: + return inputs_embeds - if flattened.shape[0] != num_expected_tokens: + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + input_dtype = inputs_embeds.dtype + + try: + # For debugging + # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype) + + # NOTE: This can avoid D2H sync (#22105), but fails to + # raise an error if is_multimodal.sum() < len(mm_embeds_flat) + inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), + mm_embeds_flat.to(dtype=input_dtype)) + except RuntimeError as e: + num_actual_tokens = len(mm_embeds_flat) + num_expected_tokens = is_multimodal.sum().item() + + if num_actual_tokens != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( - f"Attempted to assign {expr} = {flattened.shape[0]} " + f"Attempted to assign {expr} = {num_actual_tokens} " f"multimodal tokens to {num_expected_tokens} placeholders" ) from e - else: - raise ValueError("Error during masked scatter operation") from e + + raise ValueError("Error during masked scatter operation") from e return inputs_embeds -def embed_multimodal( - input_ids: torch.Tensor, - multimodal_token_id: int, - get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: NestedTensors, -) -> torch.Tensor: - """ - Embed token IDs and multimodal inputs and combine their embeddings. - - ``multimodal_token_id`` is used to determine whether a token ID should - be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. - - Compared to ``merge_multimodal_embeddings`, this avoids running - ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` - which causes issues when the placeholder token ID exceeds the - vocabulary size of the language model. - """ - is_multimodal = input_ids == multimodal_token_id - is_text = ~is_multimodal - - text_embeds = get_text_embeds(input_ids[is_text]) - merged_embeds = torch.empty( - (input_ids.shape[0], text_embeds.shape[1]), - dtype=text_embeds.dtype, - device=text_embeds.device, - ) - - merged_embeds[is_text] = text_embeds - - return _merge_multimodal_embeddings( - merged_embeds, - is_multimodal, - multimodal_embeds, - ) - - def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, @@ -491,23 +465,29 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. 
""" if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor( - placeholder_token_id, - pin_memory=is_pin_memory_available()).to(device=input_ids.device, - non_blocking=True) - return _merge_multimodal_embeddings( - inputs_embeds, - torch.isin(input_ids, placeholder_token_id), - multimodal_embeddings, - ) + is_multimodal = isin_list(input_ids, placeholder_token_id) + else: + is_multimodal = (input_ids == placeholder_token_id) return _merge_multimodal_embeddings( inputs_embeds, - (input_ids == placeholder_token_id), - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, ) +def isin_list( + elements: torch.Tensor, + test_elements_list: list[int], +) -> torch.Tensor: + test_elements = torch.tensor( + test_elements_list, + pin_memory=is_pin_memory_available(), + ).to(device=elements.device, non_blocking=True) + + return torch.isin(elements, test_elements) + + class LayerFn(Protocol): def __call__(self, prefix: str) -> torch.nn.Module: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index b33e8d09c4be1..f93e7ccfd06ff 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsTranscription) -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription +from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix logger = init_logger(__name__) @@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + audio_encoder = self.tokenizer.instruct.audio_encoder + audio_tok_id = audio_encoder.audio_token audio_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - audio_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + audio_embeddings, + is_multimodal=input_ids == audio_tok_id, + ) input_ids = None hidden_states = self.language_model.model(input_ids, @@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, return audio_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - audio_encoder = self.tokenizer.instruct.audio_encoder - audio_tok_id = audio_encoder.audio_token - - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id) - return inputs_embeds - def _parse_and_validate_audio_arrays( self, **kwargs: object) -> Union[list[torch.Tensor], None]: audio_arrays = kwargs.pop("audio_arrays", None) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index de3e4f0592a62..7beeeddf988fe 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module): hidden_states = self.layer_norm(hidden_states) return hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This method just returns the decoder sequence embeddings since # Whisper does not have encoder text tokens. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 51e54e0dc337f..1b5bafb9ca1b1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -64,8 +65,10 @@ class EagleProposer: # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() - self.is_multimodal_model = vllm_config.model_config \ - .is_multimodal_model + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + vllm_config.model_config) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None @@ -175,7 +178,8 @@ class EagleProposer: last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embeds: Optional[list[torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -219,18 +223,21 @@ class EagleProposer: # copy inputs to buffer for cudagraph self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states - if self.is_multimodal_model: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( - input_ids, - multimodal_embeddings=mm_embeds or None, + + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) - self.inputs_embeds[:num_tokens] = inputs_embeds - inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] else: - inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None with set_forward_context(per_layer_attn_metadata, self.vllm_config, @@ -372,14 +379,15 @@ class EagleProposer: self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states - if self.is_multimodal_model: - inputs_embeds = self.model.get_input_embeddings(input_ids) - self.inputs_embeds[:batch_size] = inputs_embeds - inputs_embeds = self.inputs_embeds[:input_batch_size] + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = \ + self.model.get_input_embeddings(input_ids) + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] else: - inputs_embeds = None input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None # Run the model. 
with set_forward_context(per_layer_attn_metadata, @@ -849,7 +857,7 @@ class EagleProposer: self.attn_layer_names = list(draft_attn_layer_names) - if self.is_multimodal_model: + if self.supports_mm_inputs: # Even if the target model is multimodal, we can also use # text-only draft models try: @@ -861,7 +869,7 @@ class EagleProposer: logger.warning( "Draft model does not support multimodal inputs, " "falling back to text-only mode") - self.is_multimodal_model = False + self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality @@ -933,7 +941,7 @@ class EagleProposer: ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - if self.is_multimodal_model: + if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22a177dd7cc7a..1bae0d4ce4d1f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self, scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + mm_embeds = list[torch.Tensor]() + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 should_sync_mrope_positions = False - mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] @@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens + for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position start_pos = pos_info.offset @@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True if is_embed is None else is_embed + mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], is_embed=is_embed, @@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: + assert req_state.mrope_positions is not None should_sync_mrope_positions = True mm_embeds_req, new_mrope_positions, new_delta = ( self.model.recompute_mrope_positions( @@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, )) - assert req_state.mrope_positions is not None req_state.mrope_positions.copy_(new_mrope_positions) 
req_state.mrope_position_delta = new_delta mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: self._calc_mrope_positions(scheduler_output) - self.mrope_positions.copy_to_gpu( - scheduler_output.total_num_scheduled_tokens) + self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) - return mm_embeds + return mm_embeds, is_mm_embed def _extract_encoder_inputs( self, @@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids.gpu[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) # TODO(woosuk): Avoid the copy. Optimize. @@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - mm_embeds = None + if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, + mm_embed_inputs=mm_embed_inputs, ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a330f50875a89..2405f978ca73f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context @@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens) + + is_mm_embed = self.is_mm_embed_cpu + is_mm_embed[:padded_total_num_scheduled_tokens] = False + mm_embeds = list[torch.Tensor]() + req_start_idx = 0 + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens + # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid @@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The encoder output is already processed and stored # in the decoder's KV cache. continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." + assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." - encoder_output = self.encoder_cache[mm_hash] - mm_embeds.append(encoder_output) - return mm_embeds - def _get_model_inputs(self, input_ids: torch.Tensor, - mm_embeds: list[torch.Tensor]): + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True + + # Only whole mm items are processed + mm_embeds.append(encoder_output) + + req_start_idx += num_scheduled_tokens + + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \ + .to(self.device) + + return mm_embeds, is_mm_embed + + def _get_model_inputs( + self, + input_ids: torch.Tensor, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]], + ): if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds = self.model.get_input_embeddings( - input_ids=input_ids, + input_ids, multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) + return None, inputs_embeds else: # For text-only models, we use token ids as input. @@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) else: - mm_embeds = [] + mm_embed_inputs = None + torch_xla.sync(wait=False) # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. 
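The `(mm_embeds, is_mm_embed)` pair returned by `_gather_mm_embeddings` above is consumed by `get_input_embeddings` as a purely positional merge. The snippet below is a minimal, self-contained sketch of that pairing, not part of the patch; the shapes, values, and variable names are invented for illustration.

```python
import torch

hidden_size = 4
num_scheduled_tokens = 6

# Text embeddings for every scheduled token (placeholder rows included).
inputs_embeds = torch.zeros(num_scheduled_tokens, hidden_size)

# Suppose positions 2..4 of this step correspond to one multimodal item.
is_mm_embed = torch.zeros(num_scheduled_tokens, dtype=torch.bool)
is_mm_embed[2:5] = True

# Flattened features for that item, one row per placeholder position.
mm_embeds = [torch.ones(3, hidden_size)]

# The merge overwrites exactly the masked rows, which is what
# `_merge_multimodal_embeddings` does via `masked_scatter_`.
flat = torch.cat(mm_embeds)
inputs_embeds.masked_scatter_(is_mm_embed.unsqueeze(-1), flat)

assert torch.equal(inputs_embeds[2:5], flat)
```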
@@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ end_index = self._prepare_inputs(scheduler_output, start_index) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) + self.input_ids, mm_embed_inputs) torch_xla.sync(wait=False) # Run the decoder with set_forward_context( @@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hf_config.image_token_index placeholders_ids = placeholders_ids.to(self.device) + + mm_mask = torch.tensor([False] * num_tokens) + mm_mask[:items_size] = True + mm_mask = mm_mask.to(self.device) # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs(placeholders_ids, - [mm_embeds]) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=([mm_embeds], mm_mask), + ) assert a is None torch_xla.sync(wait=False) @@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=torch.int32, device="cpu") placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs(placeholders_ids, []) + a, b = self._get_model_inputs( + placeholders_ids, + mm_embed_inputs=None, + ) assert a is None torch_xla.sync(wait=False)
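To summarize the calling convention this patch introduces: callers derive a boolean `is_multimodal` mask from the token IDs (as the updated `forward` methods above do) and pass it alongside the embeddings, so the merge no longer re-derives placeholder positions from `input_ids`. Below is a hedged, self-contained sketch of that flow with a toy model; `ToyLM` and `IMAGE_TOKEN_ID` are hypothetical stand-ins, not vLLM APIs.

```python
import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 9  # hypothetical placeholder token ID, not a vLLM constant


class ToyLM(nn.Module):
    """Stand-in for a language model that owns a text embedding table."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)


lm = ToyLM()
input_ids = torch.tensor([1, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 3])

# The caller computes the mask once from the token IDs...
is_multimodal = input_ids == IMAGE_TOKEN_ID

# ...and the merge is index-based: overwrite the masked rows with the
# flattened multimodal features instead of matching token IDs again.
inputs_embeds = lm.get_input_embeddings(input_ids)
mm_embeds = [torch.randn(2, 8)]  # one image item covering two placeholders

inputs_embeds = inputs_embeds.masked_scatter(
    is_multimodal.unsqueeze(-1),
    torch.cat(mm_embeds).to(dtype=inputs_embeds.dtype),
)
```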