From 97d1c99302df6f7eadc0d0b32ec174db69cb4421 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 13 Nov 2025 03:14:33 +0000
Subject: [PATCH] Rename clashing method names for vLLM model protocol (#27583)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 docs/contributing/model/basic.md | 4 +-
 docs/contributing/model/multimodal.md | 6 +--
 vllm/model_executor/models/apertus.py | 8 ++--
 vllm/model_executor/models/arcee.py | 8 ++--
 vllm/model_executor/models/arctic.py | 8 ++--
 vllm/model_executor/models/aria.py | 6 +--
 vllm/model_executor/models/aya_vision.py | 2 +-
 vllm/model_executor/models/baichuan.py | 8 ++--
 vllm/model_executor/models/bailing_moe.py | 8 ++--
 vllm/model_executor/models/bamba.py | 8 ++--
 vllm/model_executor/models/bert.py | 14 +++---
 vllm/model_executor/models/bert_with_rope.py | 6 +--
 vllm/model_executor/models/blip2.py | 2 +-
 vllm/model_executor/models/bloom.py | 8 ++--
 vllm/model_executor/models/chameleon.py | 8 ++--
 vllm/model_executor/models/chatglm.py | 8 ++--
 vllm/model_executor/models/clip.py | 12 +++---
 vllm/model_executor/models/cohere2_vision.py | 2 +-
 vllm/model_executor/models/commandr.py | 8 ++--
 vllm/model_executor/models/dbrx.py | 8 ++--
 vllm/model_executor/models/deepseek_eagle.py | 6 +--
 vllm/model_executor/models/deepseek_mtp.py | 6 +--
 vllm/model_executor/models/deepseek_ocr.py | 4 +-
 vllm/model_executor/models/deepseek_v2.py | 8 ++--
 vllm/model_executor/models/deepseek_vl2.py | 2 +-
 vllm/model_executor/models/dots1.py | 8 ++--
 vllm/model_executor/models/dots_ocr.py | 6 +--
 vllm/model_executor/models/ernie45_moe.py | 8 ++--
 vllm/model_executor/models/ernie45_vl.py | 10 ++---
 vllm/model_executor/models/ernie45_vl_moe.py | 8 ++--
 vllm/model_executor/models/ernie_mtp.py | 6 +--
 vllm/model_executor/models/exaone.py | 8 ++--
 vllm/model_executor/models/exaone4.py | 8 ++--
 vllm/model_executor/models/falcon.py | 8 ++--
 vllm/model_executor/models/falcon_h1.py | 8 ++--
 vllm/model_executor/models/fuyu.py | 2 +-
 vllm/model_executor/models/gemma.py | 8 ++--
 vllm/model_executor/models/gemma2.py | 8 ++--
 vllm/model_executor/models/gemma3.py | 8 ++--
 vllm/model_executor/models/gemma3_mm.py | 2 +-
 vllm/model_executor/models/gemma3n.py | 12 +++---
 vllm/model_executor/models/gemma3n_mm.py | 12 +++---
 vllm/model_executor/models/glm4.py | 4 +-
 vllm/model_executor/models/glm4_1v.py | 4 +-
 vllm/model_executor/models/glm4_moe.py | 8 ++--
 vllm/model_executor/models/glm4_moe_mtp.py | 6 +--
 vllm/model_executor/models/glm4v.py | 4 +-
 vllm/model_executor/models/gpt2.py | 12 +++---
 vllm/model_executor/models/gpt_bigcode.py | 8 ++--
 vllm/model_executor/models/gpt_j.py | 8 ++--
 vllm/model_executor/models/gpt_neox.py | 8 ++--
 vllm/model_executor/models/gpt_oss.py | 8 ++--
 vllm/model_executor/models/granite.py | 8 ++--
 vllm/model_executor/models/granite_speech.py | 8 ++--
 vllm/model_executor/models/granitemoe.py | 8 ++--
 .../model_executor/models/granitemoehybrid.py | 8 ++--
 .../model_executor/models/granitemoeshared.py | 8 ++--
 vllm/model_executor/models/grok1.py | 8 ++--
 vllm/model_executor/models/hunyuan_v1.py | 8 ++--
 .../models/hyperclovax_vision.py | 2 +-
 vllm/model_executor/models/idefics3.py | 6 +--
 vllm/model_executor/models/interfaces.py | 32 ++++++++------
 vllm/model_executor/models/interfaces_base.py | 43 ++++++++++---------
 vllm/model_executor/models/internlm2.py | 8 ++--
 vllm/model_executor/models/interns1.py | 8 ++--
 vllm/model_executor/models/internvl.py | 8 ++--
 vllm/model_executor/models/jais.py | 8 ++--
 vllm/model_executor/models/jamba.py | 8 ++--
 vllm/model_executor/models/keye.py | 4 +-
 vllm/model_executor/models/kimi_linear.py | 8 ++--
 vllm/model_executor/models/kimi_vl.py | 2 +-
 vllm/model_executor/models/lfm2.py | 8 ++--
 vllm/model_executor/models/lfm2_moe.py | 8 ++--
 vllm/model_executor/models/llama.py | 8 ++--
 vllm/model_executor/models/llama4_eagle.py | 6 +--
 vllm/model_executor/models/llama_eagle.py | 6 +--
 vllm/model_executor/models/llama_eagle3.py | 8 ++--
 vllm/model_executor/models/llava.py | 2 +-
 vllm/model_executor/models/llava_next.py | 8 ++--
 .../model_executor/models/llava_next_video.py | 2 +-
 vllm/model_executor/models/llava_onevision.py | 2 +-
 vllm/model_executor/models/longcat_flash.py | 8 ++--
 vllm/model_executor/models/mamba.py | 8 ++--
 vllm/model_executor/models/mamba2.py | 8 ++--
 vllm/model_executor/models/midashenglm.py | 2 +-
 vllm/model_executor/models/mimo.py | 2 +-
 vllm/model_executor/models/mimo_mtp.py | 6 +--
 vllm/model_executor/models/minicpm.py | 8 ++--
 vllm/model_executor/models/minicpm_eagle.py | 8 ++--
 vllm/model_executor/models/minicpmv.py | 2 +-
 vllm/model_executor/models/minimax_m2.py | 8 ++--
 vllm/model_executor/models/minimax_text_01.py | 6 +--
 vllm/model_executor/models/minimax_vl_01.py | 6 +--
 vllm/model_executor/models/mistral3.py | 2 +-
 vllm/model_executor/models/mixtral.py | 8 ++--
 vllm/model_executor/models/mllama4.py | 2 +-
 vllm/model_executor/models/modernbert.py | 14 +++---
 vllm/model_executor/models/molmo.py | 4 +-
 vllm/model_executor/models/mpt.py | 8 ++--
 .../model_executor/models/nano_nemotron_vl.py | 6 +--
 vllm/model_executor/models/nemotron.py | 8 ++--
 vllm/model_executor/models/nemotron_h.py | 8 ++--
 vllm/model_executor/models/nemotron_nas.py | 8 ++--
 vllm/model_executor/models/nemotron_vl.py | 8 ++--
 vllm/model_executor/models/olmo.py | 8 ++--
 vllm/model_executor/models/olmo2.py | 6 +--
 vllm/model_executor/models/olmoe.py | 8 ++--
 vllm/model_executor/models/openpangu.py | 8 ++--
 vllm/model_executor/models/openpangu_mtp.py | 4 +-
 vllm/model_executor/models/opt.py | 12 +++---
 vllm/model_executor/models/orion.py | 8 ++--
 vllm/model_executor/models/ouro.py | 8 ++--
 vllm/model_executor/models/ovis.py | 2 +-
 vllm/model_executor/models/ovis2_5.py | 2 +-
 vllm/model_executor/models/paddleocr_vl.py | 6 +--
 vllm/model_executor/models/paligemma.py | 2 +-
 vllm/model_executor/models/persimmon.py | 8 ++--
 vllm/model_executor/models/phi.py | 8 ++--
 vllm/model_executor/models/phi3v.py | 8 ++--
 vllm/model_executor/models/phi4_multimodal.py | 2 +-
 vllm/model_executor/models/phi4mm.py | 2 +-
 vllm/model_executor/models/phimoe.py | 8 ++--
 vllm/model_executor/models/pixtral.py | 2 +-
 vllm/model_executor/models/plamo2.py | 8 ++--
 vllm/model_executor/models/qwen.py | 4 +-
 vllm/model_executor/models/qwen2.py | 8 ++--
 .../models/qwen2_5_omni_thinker.py | 10 ++---
 vllm/model_executor/models/qwen2_5_vl.py | 2 +-
 vllm/model_executor/models/qwen2_audio.py | 2 +-
 vllm/model_executor/models/qwen2_moe.py | 8 ++--
 vllm/model_executor/models/qwen2_rm.py | 4 +-
 vllm/model_executor/models/qwen2_vl.py | 2 +-
 vllm/model_executor/models/qwen3.py | 4 +-
 vllm/model_executor/models/qwen3_moe.py | 8 ++--
 vllm/model_executor/models/qwen3_next.py | 8 ++--
 vllm/model_executor/models/qwen3_next_mtp.py | 8 ++--
 .../models/qwen3_omni_moe_thinker.py | 12 +++---
 vllm/model_executor/models/qwen3_vl.py | 14 +++---
 vllm/model_executor/models/qwen3_vl_moe.py | 2 +-
 vllm/model_executor/models/qwen_vl.py | 2 +-
 vllm/model_executor/models/roberta.py | 4 +-
 vllm/model_executor/models/seed_oss.py | 8 ++--
 vllm/model_executor/models/siglip.py | 10 ++---
 vllm/model_executor/models/skyworkr1v.py | 8 ++--
 vllm/model_executor/models/solar.py | 8 ++--
 vllm/model_executor/models/stablelm.py | 8 ++--
 vllm/model_executor/models/starcoder2.py | 8 ++--
 vllm/model_executor/models/step3_text.py | 8 ++--
 vllm/model_executor/models/step3_vl.py | 12 +++---
 vllm/model_executor/models/tarsier.py | 6 +--
 vllm/model_executor/models/teleflm.py | 2 +-
 vllm/model_executor/models/terratorch.py | 2 +-
 .../models/transformers/base.py | 4 +-
 .../models/transformers/multimodal.py | 2 +-
 vllm/model_executor/models/ultravox.py | 8 ++--
 vllm/model_executor/models/utils.py | 2 +-
 vllm/model_executor/models/voxtral.py | 2 +-
 vllm/model_executor/models/whisper.py | 10 ++---
 vllm/model_executor/models/zamba2.py | 8 ++--
 vllm/multimodal/processing.py | 2 +-
 vllm/v1/spec_decode/eagle.py | 10 ++---
 vllm/v1/worker/gpu_model_runner.py | 10 ++---
 vllm/v1/worker/tpu_model_runner.py | 22 +++++-----
 vllm/v1/worker/utils.py | 8 ++--
 164 files changed, 574 insertions(+), 583 deletions(-)

diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index 795bd5507a613..a7b54f015c2da 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -56,13 +56,13 @@ The initialization code should look like this:
 
 ### Computation Code
 
-- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
+- Add a `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
 
     ```python
     class MyModel(nn.Module):
         ...
 
-        def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
             ...
     ```
 
diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md
index 4e74afc688cf7..c2ca199220a1a 100644
--- a/docs/contributing/model/multimodal.md
+++ b/docs/contributing/model/multimodal.md
@@ -36,7 +36,7 @@ Further update the model as follows:
 
     More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
 
-- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
+- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
 
    ??? code
 
@@ -49,7 +49,7 @@ Further update the model as follows:
             image_features = self.vision_encoder(image_input)
             return self.multi_modal_projector(image_features)
 
-        def get_multimodal_embeddings(
+        def embed_multimodal(
             self,
             **kwargs: object,
         ) -> MultiModalEmbeddings | None:
@@ -69,7 +69,7 @@ Further update the model as follows:
 
 !!! note
     By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
-    This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
+    This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
 
     You may override this method if additional logic is required for your model when merging embeddings.
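As a quick orientation before the per-model hunks, the sketch below shows the two renamed protocol methods side by side on a toy model. This is illustrative only, patterned on the docs boilerplate above; `MyTextModel`, `MyMultiModalModel`, and the `pixel_values` argument are hypothetical stand-ins, not code from this patch.

```python
import torch
import torch.nn as nn


class MyTextModel(nn.Module):
    """Toy text backbone exposing the renamed embed_input_ids() hook."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Formerly get_input_embeddings(); the behavior is unchanged.
        return self.embed_tokens(input_ids)


class MyMultiModalModel(nn.Module):
    """Toy composite model exposing embed_multimodal() alongside it."""

    def __init__(self) -> None:
        super().__init__()
        self.language_model = MyTextModel()
        self.vision_encoder = nn.Linear(4, 8)

    def embed_multimodal(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Formerly get_multimodal_embeddings(); the behavior is unchanged.
        return self.vision_encoder(pixel_values)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Composite models delegate to the inner text model, exactly as the
        # per-model hunks below do with `self.model` / `self.transformer`.
        return self.language_model.embed_input_ids(input_ids)
```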
diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py
index 233b8c79f2992..0a8f21abb0a35 100644
--- a/vllm/model_executor/models/apertus.py
+++ b/vllm/model_executor/models/apertus.py
@@ -382,7 +382,7 @@ class ApertusModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -396,7 +396,7 @@ class ApertusModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -557,8 +557,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py
index f33970aff279c..20c3ff0754506 100644
--- a/vllm/model_executor/models/arcee.py
+++ b/vllm/model_executor/models/arcee.py
@@ -239,7 +239,7 @@ class ArceeModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -254,7 +254,7 @@ class ArceeModel(nn.Module):
             hidden_states = (
                 inputs_embeds
                 if inputs_embeds is not None
-                else self.get_input_embeddings(input_ids)
+                else self.embed_input_ids(input_ids)
             )
             residual = None
         else:
@@ -423,8 +423,8 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights into the model (delegates to inner model and handles
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index ae3b96c83509d..b5cc07a56535d 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -442,7 +442,7 @@ class ArcticModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -456,7 +456,7 @@ class ArcticModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -496,8 +496,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index fe37487d6ed88..3d07e6b612ca3 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -613,7 +613,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -629,8 +629,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         **kwargs: object,
     ) -> torch.Tensor | IntermediateTensors:
         if inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
+            multimodal_embeddings = self.embed_multimodal(**kwargs)
+            inputs_embeds = self.embed_input_ids(
                 input_ids,
                 multimodal_embeddings,
                 is_multimodal=input_ids == self.config.image_token_index,
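The call-site substitution repeated across the model files below follows the shape of the aria.py hunk above: `get_multimodal_embeddings(...)` becomes `embed_multimodal(...)` and `get_input_embeddings(...)` becomes `embed_input_ids(...)`. A minimal sketch of that prefill wiring; the helper name `prefill_embeds` and the `image_token_index` attribute are assumed for illustration, not taken from this patch.

```python
import torch


def prefill_embeds(model, input_ids: torch.Tensor, **mm_kwargs) -> torch.Tensor:
    # Encode any multimodal inputs (images, audio, ...) into embeddings.
    multimodal_embeddings = model.embed_multimodal(**mm_kwargs)
    # Embed the token ids, scattering the multimodal embeddings over the
    # positions that hold placeholder image tokens.
    return model.embed_input_ids(
        input_ids,
        multimodal_embeddings,
        is_multimodal=input_ids == model.config.image_token_index,
    )
```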
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 839ab5947e094..0ada2ed5028bb 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index dac012eb9f829..8991ef4c606b6 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -309,7 +309,7 @@ class BaiChuanModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -323,7 +323,7 @@ class BaiChuanModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -426,8 +426,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index 641bdb69c366c..a878134022565 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -438,7 +438,7 @@ class BailingMoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.word_embeddings(input_ids)
 
     def forward(
@@ -452,7 +452,7 @@ class BailingMoeModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -608,8 +608,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 4a2b3da1c194d..e0a2defd5127e 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -314,7 +314,7 @@ class BambaModel(nn.Module):
 
         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -328,7 +328,7 @@ class BambaModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -493,8 +493,8 @@ class BambaForCausalLM(
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 1c2334a785437..2679448bce775 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -375,7 +375,7 @@ class BertModel(nn.Module, SupportsQuant):
         self.embeddings = embedding_class(self.config)
         self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings.word_embeddings(input_ids)
 
     def forward(
@@ -486,8 +486,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         )
         self.pooler = self._build_pooler(pooler_config)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
@@ -835,8 +835,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
             }
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.bert.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.embed_input_ids(input_ids)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -893,8 +893,8 @@ class BertForTokenClassification(nn.Module):
             }
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.bert.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.embed_input_ids(input_ids)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 31fdc4d21245a..131cb68914cf3 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -463,7 +463,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         )
         self.pooler = BertPooler(self.config) if add_pooling_layer else None
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings(input_ids)
 
     def forward(
@@ -714,8 +714,8 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
         loaded_params = loader.load_weights(weights)
         return loaded_params
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.new.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.new.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 2986a72f2e487..f71b9c01d359d 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -630,7 +630,7 @@ class Blip2ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 18b09ee43b7b0..00fba93423d8e 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -271,7 +271,7 @@ class BloomModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.word_embeddings(input_ids)
 
     def forward(
@@ -285,7 +285,7 @@ class BloomModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             hidden_states = self.word_embeddings_layernorm(hidden_states)
         else:
             assert intermediate_tensors is not None
@@ -353,8 +353,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 64f73e938bf6c..fb7476c45fcdb 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
@@ -912,7 +912,7 @@ class ChameleonModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -998,7 +998,7 @@ class ChameleonForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -1006,7 +1006,7 @@ class ChameleonForConditionalGeneration(
         image_tokens = self.model.get_image_tokens(
             image_input["data"].to(self.config.dtype)
         )
-        vision_embeddings = self.model.get_input_embeddings(image_tokens)
+        vision_embeddings = self.model.embed_input_ids(image_tokens)
         return vision_embeddings
 
     def forward(
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index ccf7c93001664..5d6f5e9125a28 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -353,7 +353,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
            self.encoder.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)
 
     def forward(
@@ -368,7 +368,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -451,8 +451,8 @@ class ChatGLMBaseModel(nn.Module):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def compute_logits(
         self,
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 27953c27188d9..50f476dfd185b 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -561,7 +561,7 @@ class CLIPTextTransformer(nn.Module):
             eps=config.layer_norm_eps,
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings.token_embedding(input_ids)
 
     def forward(
@@ -842,7 +842,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
             }
         )
 
-        # Assumes that self.forward is called after self.get_input_embeddings
+        # Assumes that self.forward is called after self.embed_input_ids
         self._is_text_input = True
 
     def get_text_features(
@@ -903,7 +903,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     def get_language_model(self) -> torch.nn.Module:
         return self.text_model
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -917,16 +917,16 @@
 
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)
 
-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
             handle_oov_mm_token=handle_oov_mm_token,
         )
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
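clip.py above (and ernie45_vl.py and gemma3n_mm.py later in this patch) keeps a thin `embed_input_ids` override whose only job is to dispatch between the base-class overloads for the type checker. A simplified sketch of that dispatch shape follows; `Base` is an assumed stand-in for vLLM's `SupportsMultiModal` merge behavior, not the actual implementation.

```python
import torch
import torch.nn as nn


class Base(nn.Module):
    """Assumed stand-in for the SupportsMultiModal default implementation."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: torch.Tensor | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None and is_multimodal is not None:
            # Scatter multimodal embeddings into the placeholder positions.
            embeds = embeds.clone()
            embeds[is_multimodal] = multimodal_embeddings
        return embeds


class Child(Base):
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: torch.Tensor | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload, as in the
        # clip.py hunk above: text-only calls take the first branch.
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)
        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )
```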
diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py
index 19cc31c9bd18b..139ccba9df6d8 100644
--- a/vllm/model_executor/models/cohere2_vision.py
+++ b/vllm/model_executor/models/cohere2_vision.py
@@ -439,7 +439,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 6ae1dc3560827..77bb178519813 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -311,7 +311,7 @@ class CohereModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -325,7 +325,7 @@ class CohereModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -436,8 +436,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     @torch.no_grad()
     def forward(
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 70999501f4c69..528ef4f76742d 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -354,7 +354,7 @@ class DbrxModel(nn.Module):
             ["hidden_states"], config.d_model
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
 
     def forward(
@@ -368,7 +368,7 @@ class DbrxModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors
             hidden_states = intermediate_tensors["hidden_states"]
@@ -455,8 +455,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index fd2f20ea501d0..9e834a73f8e5e 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -73,7 +73,7 @@ class DeepseekV2Model(nn.Module):
         self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
         self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -222,8 +222,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
         self.num_moe_layers = self.config.num_hidden_layers
         self.set_moe_parameters()
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 26b9c25e6bdb5..e028dc497aa6a 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -142,7 +142,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -206,8 +206,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
                 self.moe_layers.append(layer.mlp.experts)
 
         self.extract_moe_parameters(example_moe)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py
index 0432567521843..c89caab93a1ee 100644
--- a/vllm/model_executor/models/deepseek_ocr.py
+++ b/vllm/model_executor/models/deepseek_ocr.py
@@ -557,9 +557,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-        self, **kwargs: object
-    ) -> MultiModalEmbeddings | None:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 38189e17f7d8b..115818d903a6d 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1236,7 +1236,7 @@ class DeepseekV2Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -1250,7 +1250,7 @@ class DeepseekV2Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -1389,8 +1389,8 @@ class DeepseekV2ForCausalLM(
 
         self.extract_moe_parameters(example_moe)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 306eef3dca990..e7b48e0f4e554 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -619,7 +619,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 15caa3184581d..d24da0c42a254 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -398,7 +398,7 @@ class Dots1Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -412,7 +412,7 @@ class Dots1Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -541,8 +541,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 1b2bb60a17c16..25e5588961a63 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -840,7 +840,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -858,8 +858,8 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
         if intermediate_tensors is not None:
             inputs_embeds = None
         elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
+            vision_embeddings = self.embed_multimodal(**kwargs)
+            inputs_embeds = self.embed_input_ids(
                 input_ids,
                 vision_embeddings,
                 is_multimodal=input_ids == self.config.image_token_id,
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index b35666175ea7b..f2999968669f6 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -465,7 +465,7 @@ class Ernie4_5_MoeModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -479,7 +479,7 @@ class Ernie4_5_MoeModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -726,8 +726,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
             moe.n_redundant_experts = self.num_redundant_experts
             moe.experts.update_expert_map()
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index c040b19bba20e..daa5bf03ea4a9 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -1656,9 +1656,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
 
         return modalities
 
-    def get_multimodal_embeddings(
-        self, **kwargs: object
-    ) -> MultiModalEmbeddings | None:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
             return None
@@ -1681,7 +1679,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -1694,9 +1692,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(
 
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)
 
-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index d002d1838c8ea..e8ef86f9b7f01 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -561,7 +561,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -577,7 +577,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -642,8 +642,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py
index e7036840388cc..1b9abc3572a3b 100644
--- a/vllm/model_executor/models/ernie_mtp.py
+++ b/vllm/model_executor/models/ernie_mtp.py
@@ -112,7 +112,7 @@ class ErnieMultiTokenPredictor(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -160,8 +160,8 @@ class ErnieMTP(nn.Module, SupportsPP):
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index b9c7a520caffb..6c56bfc433c7a 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -357,7 +357,7 @@ class ExaoneModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
 
     def forward(
@@ -371,7 +371,7 @@ class ExaoneModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -512,8 +512,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index 6a5c888c095ae..b89e168ada20e 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -344,7 +344,7 @@ class Exaone4Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -358,7 +358,7 @@ class Exaone4Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -498,8 +498,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 1b9c7da334909..85acdff3d96b4 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -399,7 +399,7 @@ class FalconModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.word_embeddings(input_ids)
 
     def forward(
@@ -413,7 +413,7 @@ class FalconModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
         for layer in islice(self.h, self.start_layer, self.end_layer):
@@ -515,8 +515,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 38838be29093e..3653425b8e1ca 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -461,7 +461,7 @@ class FalconH1Model(nn.Module):
         else:
             self.final_layernorm = PPMissingLayer()
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -476,7 +476,7 @@ class FalconH1Model(nn.Module):
                 hidden_states = inputs_embeds * self.embedding_multiplier
             else:
                 hidden_states = (
-                    self.get_input_embeddings(input_ids) * self.embedding_multiplier
+                    self.embed_input_ids(input_ids) * self.embedding_multiplier
                 )
         else:
             assert intermediate_tensors is not None
@@ -601,8 +601,8 @@ class FalconH1ForCausalLM(
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 005fac4b1f05d..269c36ab5b9c7 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -333,7 +333,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index caeee7c2e1ecc..7aaae7c503b58 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -307,7 +307,7 @@ class GemmaModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
                 hidden_states *= self.normalizer
             residual = None
         else:
@@ -396,8 +396,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index efd01535fc3ef..4d5d6cbb37c62 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -304,7 +304,7 @@ class Gemma2Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
                 hidden_states *= self.normalizer
             residual = None
         else:
@@ -409,8 +409,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 213f9f562f8a0..357e61a4e78bf 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -393,7 +393,7 @@ class Gemma3Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         # NOTE(woosuk): Only apply the normalizer to the output of
         # vocab embedding. Don't apply it to the vision embedding.
         return self.embed_tokens(input_ids) * self.normalizer
@@ -410,7 +410,7 @@ class Gemma3Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -540,8 +540,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 8e1dbd9e2cea7..02fb7ef31dc94 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 22d51ab762692..64443190f53ed 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -685,7 +685,7 @@ class Gemma3nSelfDecoder(nn.Module):
             per_layer_inputs = per_layer_projection
         return per_layer_inputs
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids) * self.embed_scale
 
     def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
@@ -712,7 +712,7 @@ class Gemma3nSelfDecoder(nn.Module):
         if inputs_embeds is not None:
             hidden_states_0 = inputs_embeds
         else:
-            hidden_states_0 = self.get_input_embeddings(input_ids)
+            hidden_states_0 = self.embed_input_ids(input_ids)
 
         adjusted_per_layer_inputs = self.get_per_layer_inputs(
             hidden_states_0, per_layer_inputs
@@ -881,8 +881,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
     def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.self_decoder.get_per_layer_input_embeddings(input_ids)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.self_decoder.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.self_decoder.embed_input_ids(input_ids)
 
     def fast_prefill_forward(
         self,
@@ -1125,8 +1125,8 @@ class Gemma3nForCausalLM(nn.Module):
             config.vocab_size, soft_cap=config.final_logit_softcapping
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 2b727a538bf25..6ae76976eb46c 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -645,7 +645,7 @@ class Gemma3nForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
         if mm_input_by_modality is None:
             return []
@@ -664,7 +664,7 @@ class Gemma3nForConditionalGeneration(
                 multimodal_embeddings.extend(audio_embeddings)
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
@@ -689,9 +689,9 @@ class Gemma3nForConditionalGeneration(
 
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)
 
-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
@@ -709,10 +709,10 @@ class Gemma3nForConditionalGeneration(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE (NickLucche) During profiling, `get_input_embeddings` is not
+        # NOTE (NickLucche) During profiling, `embed_input_ids` is not
         # called, hence we don't have input_ids to compute PLEs. We simply
         # select a chunk of pre-allocated PLEs. During normal execution,
-        # `get_input_embeddings` is called before forward, hence this slice
+        # `embed_input_ids` is called before forward, hence this slice
         # will contain PLEs computed from the actual input_ids.
         per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
 
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index 4172f16737c18..faa0674a2e43d 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -275,8 +275,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 60cad2e2907f2..b2d4fe0c0139b 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -1594,9 +1594,7 @@ class Glm4vForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-        self, **kwargs: object
-    ) -> MultiModalEmbeddings | None:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not mm_input_by_modality:
             return None
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index b30bd66161da9..1422dbe9b3cd0 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -455,7 +455,7 @@ class Glm4MoeModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -469,7 +469,7 @@ class Glm4MoeModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -704,8 +704,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper
 
         self.extract_moe_parameters(example_moe)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
index 9db2aaa075de1..110ed0a646334 100644
--- a/vllm/model_executor/models/glm4_moe_mtp.py
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -149,7 +149,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -211,8 +211,8 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
                 self.moe_layers.append(layer.mlp.experts)
 
         self.extract_moe_parameters(example_moe)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index 899797a510539..1c18ea0745f2b 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -756,9 +756,9 @@ class GLM4VForCausalLM(
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
-    get_input_embeddings = SupportsMultiModal.get_input_embeddings
+    embed_input_ids = SupportsMultiModal.embed_input_ids
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 6d99d02a32be2..a5e8131c7fba9 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -213,7 +213,7 @@ class GPT2Model(nn.Module):
             ["hidden_states"], config.n_embd
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
 
     def forward(
@@ -225,7 +225,7 @@ class GPT2Model(nn.Module):
    ) -> torch.Tensor | IntermediateTensors:
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
-                inputs_embeds = self.get_input_embeddings(input_ids)
+                inputs_embeds = self.embed_input_ids(input_ids)
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -293,8 +293,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
@@ -365,8 +365,8 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
             }
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 99cdaabb98dfe..cdf038ba25c92 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -230,7 +230,7 @@ class GPTBigCodeModel(nn.Module):
             ["hidden_states"], config.n_embd
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
 
     def forward(
@@ -242,7 +242,7 @@ class GPTBigCodeModel(nn.Module):
     ) -> torch.Tensor | IntermediateTensors:
         if get_pp_group().is_first_rank:
             if inputs_embeds is None:
-                inputs_embeds = self.get_input_embeddings(input_ids)
+                inputs_embeds = self.embed_input_ids(input_ids)
             hidden_states = inputs_embeds + self.wpe(position_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
@@ -306,8 +306,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index e04b2465e54ae..e416ecde0c1e0 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -215,7 +215,7 @@ class GPTJModel(nn.Module):
             ["hidden_states"], config.n_embd
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
 
     def forward(
@@ -229,7 +229,7 @@ class GPTJModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
         for layer in islice(self.h, self.start_layer, self.end_layer):
@@ -319,8 +319,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index e6c145602d29a..af0c9209231cb 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -229,7 +229,7 @@ class GPTNeoXModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_in(input_ids)
 
     def forward(
@@ -243,7 +243,7 @@ class GPTNeoXModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states =
self.embed_input_ids(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for layer in islice(self.layers, self.start_layer, self.end_layer): @@ -317,8 +317,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): self.gpt_neox.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.gpt_neox.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 9cb481fc30c79..692ef605fe175 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -269,7 +269,7 @@ class GptOssModel(nn.Module): ) self.aux_hidden_state_layers = tuple[int, ...]() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) def forward( @@ -283,7 +283,7 @@ class GptOssModel(nn.Module): if inputs_embeds is not None: x = inputs_embeds else: - x = self.get_input_embeddings(input_ids) + x = self.embed_input_ids(input_ids) residual = None else: @@ -703,8 +703,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 5fc8718ca75e5..c44b4021471ef 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -318,7 +318,7 @@ class GraniteModel(nn.Module): else: self.norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -332,7 +332,7 @@ class GraniteModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) hidden_states *= self.config.embedding_multiplier else: @@ -473,8 +473,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): else: self.lm_head = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 3ddf02bbba2ea..1797adab8d146 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -767,7 +767,7 @@ class GraniteSpeechForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( + def embed_multimodal( self, **kwargs: object, ) -> MultiModalEmbeddings: @@ -779,7 +779,7 @@ class GraniteSpeechForConditionalGeneration( audio_features = self._process_audio_input(audio_input) return audio_features - def get_input_embeddings( + def 
embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -790,9 +790,9 @@ class GraniteSpeechForConditionalGeneration( ) -> torch.Tensor: # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index c5b36c362ee32..5c6759ded0669 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -315,7 +315,7 @@ class GraniteMoeModel(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -329,7 +329,7 @@ class GraniteMoeModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) hidden_states *= self.embedding_multiplier else: assert intermediate_tensors is not None @@ -531,8 +531,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): scale=1 / self.config.logits_scaling, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index ea49a0ffee011..05177f1d1ac2c 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -366,7 +366,7 @@ class GraniteMoeHybridModel(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -380,7 +380,7 @@ class GraniteMoeHybridModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) hidden_states = hidden_states * self.embedding_multiplier residual = None else: @@ -680,8 +680,8 @@ class GraniteMoeHybridForCausalLM( self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index e08e9f73ec879..926c539af33be 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -182,7 +182,7 @@ class GraniteMoeSharedModel(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: 
torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -196,7 +196,7 @@ class GraniteMoeSharedModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) hidden_states *= self.embedding_multiplier else: assert intermediate_tensors is not None @@ -295,8 +295,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): scale=1 / self.config.logits_scaling, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 0770e03b5356e..9dc231863f74f 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -334,7 +334,7 @@ class Grok1Model(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) hidden_states = hidden_states * self.embedding_multiplier_scale return hidden_states @@ -350,7 +350,7 @@ class Grok1Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -522,8 +522,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index a05a00932c13b..1eadcbe67ade3 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -643,7 +643,7 @@ class HunYuanModel(nn.Module): else: self.norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -657,7 +657,7 @@ class HunYuanModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -987,8 +987,8 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): ) return loader.load_weights(weights) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 3d28ba951b94e..db46353efde5c 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -732,7 
+732,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( + def embed_multimodal( self, **kwargs: object, ) -> MultiModalEmbeddings: diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 06ca8c4886341..9c5f9389e54bb 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -550,8 +550,8 @@ class Idefics3Model(nn.Module): return image_hidden_states - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.text_model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.text_model.embed_input_ids(input_ids) def forward( self, @@ -674,7 +674,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 88b45bf07c0d8..929bfaaee5cbb 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -94,7 +94,7 @@ class SupportsMultiModal(Protocol): """ ... - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: """ Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. @@ -104,7 +104,13 @@ class SupportsMultiModal(Protocol): the appearances of their corresponding multimodal data item in the input prompt. """ - ... + if hasattr(self, "get_multimodal_embeddings"): + logger.warning_once( + "`get_multimodal_embeddings` for vLLM models is deprecated and will be " + "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename " + "this method to `embed_multimodal`." + ) + return self.get_multimodal_embeddings(**kwargs) def get_language_model(self) -> VllmModel: """ @@ -119,10 +125,10 @@ class SupportsMultiModal(Protocol): ... @overload - def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ... + def embed_input_ids(self, input_ids: Tensor) -> Tensor: ... @overload - def get_input_embeddings( + def embed_input_ids( self, input_ids: Tensor, multimodal_embeddings: MultiModalEmbeddings, @@ -131,17 +137,17 @@ class SupportsMultiModal(Protocol): handle_oov_mm_token: bool = False, ) -> Tensor: ... 
-    def _get_text_embeddings(
+    def _embed_text_input_ids(
         self,
         input_ids: Tensor,
-        get_input_embeddings: Callable[[Tensor], Tensor],
+        embed_input_ids: Callable[[Tensor], Tensor],
         *,
         is_multimodal: Tensor | None,
         handle_oov_mm_token: bool,
     ) -> Tensor:
         if handle_oov_mm_token and is_multimodal is not None:
             is_text = ~is_multimodal
-            text_embeds = get_input_embeddings(input_ids[is_text])
+            text_embeds = embed_input_ids(input_ids[is_text])
 
             return torch.empty(
                 (input_ids.shape[0], text_embeds.shape[1]),
@@ -149,9 +155,9 @@ class SupportsMultiModal(Protocol):
                 device=text_embeds.device,
             ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
 
-        return get_input_embeddings(input_ids)
+        return embed_input_ids(input_ids)
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -167,15 +173,15 @@ class SupportsMultiModal(Protocol):
 
         In case the multi-modal token IDs exceed the vocabulary size of
         the language model, you can set `handle_oov_mm_token=True`
-        to avoid calling the language model's `get_input_embeddings` method
+        to avoid calling the language model's `embed_input_ids` method
         on those tokens. Note however that doing so increases memory usage
         as an additional buffer is needed to hold the input embeddings.
         """
         from .utils import _merge_multimodal_embeddings
 
-        inputs_embeds = self._get_text_embeddings(
+        inputs_embeds = self._embed_text_input_ids(
             input_ids,
-            self.get_language_model().get_input_embeddings,
+            self.get_language_model().embed_input_ids,
             is_multimodal=is_multimodal,
             handle_oov_mm_token=handle_oov_mm_token,
         )
@@ -185,7 +191,7 @@ class SupportsMultiModal(Protocol):
 
         if is_multimodal is None:
             raise ValueError(
-                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "`embed_input_ids` now requires `is_multimodal` arg, "
                 "please update your model runner according to "
                 "https://github.com/vllm-project/vllm/pull/16229."
             )
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index d87a65a47083c..4267b6c6598e2 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
 class VllmModel(Protocol[T_co]):
     """The interface required for all models in vLLM."""
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None: ...
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-    ) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         """Apply token embeddings to `input_ids`."""
-        ...
+        if hasattr(self, "get_input_embeddings"):
+            logger.warning_once(
+                "`get_input_embeddings` for vLLM models is deprecated and will be "
+                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
+                "this method to `embed_input_ids`."
+            )
+            return self.get_input_embeddings(input_ids)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-    ) -> T_co: ...
def _check_vllm_model_init(model: type[object] | object) -> bool:
@@ -66,11 +61,20 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
     return supports_kw(model_init, "vllm_config")
 
 
-def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
-    model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
-    if not callable(model_get_input_embeddings):
+def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
+    model_embed_input_ids = getattr(model, "embed_input_ids", None)
+    if not callable(model_embed_input_ids):
+        model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
+        if callable(model_get_input_embeddings):
+            logger.warning(
+                "`get_input_embeddings` for vLLM models is deprecated and will be "
+                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
+                "this method to `embed_input_ids`."
+            )
+            model.embed_input_ids = model_get_input_embeddings
+            return True
         logger.warning(
-            "The model (%s) is missing the `get_input_embeddings` method.",
+            "The model (%s) is missing the `embed_input_ids` method.",
             model,
         )
         return False
@@ -110,7 +114,7 @@ def is_vllm_model(
 ) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
     return (
         _check_vllm_model_init(model)
-        and _check_vllm_model_get_input_embeddings(model)
+        and _check_vllm_model_embed_input_ids(model)
         and _check_vllm_model_forward(model)
     )
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index d856f5c79e33d..60fbeb842dd4b 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -284,7 +284,7 @@ class InternLM2Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.tok_embeddings(input_ids)
 
     def forward(
@@ -298,7 +298,7 @@ class InternLM2Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -350,8 +350,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py
index 1f251935a70a9..c2195fd0cb88d 100644
--- a/vllm/model_executor/models/interns1.py
+++ b/vllm/model_executor/models/interns1.py
@@ -742,7 +742,7 @@ class InternS1ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
             return []
@@ -765,7 +765,7 @@ class InternS1ForConditionalGeneration(
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -778,9 +778,9 @@ class InternS1ForConditionalGeneration(
     ) -> torch.Tensor:
         # This is to satisfy the type checker for each
overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e2d2647f01777..ccbde115009d2 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -1344,7 +1344,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1367,7 +1367,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) return multimodal_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -1380,9 +1380,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 782ab6f1e2da2..5549a1fc1cd30 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -275,7 +275,7 @@ class JAISModel(nn.Module): ["hidden_states"], config.n_embd ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) def forward( @@ -287,7 +287,7 @@ class JAISModel(nn.Module): ) -> IntermediateTensors | torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.embed_input_ids(input_ids) if self.wpe is not None: position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -339,8 +339,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP): self.transformer.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 70f52e3106f81..3a2c98c73dab4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -340,7 +340,7 @@ class JambaModel(nn.Module): self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -354,7 +354,7 @@ class JambaModel(nn.Module): if inputs_embeds is not 
None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -508,8 +508,8 @@ class JambaForCausalLM( self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 2998c87918a99..1eb0eccc0411c 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1484,9 +1484,7 @@ class BaseKeyeModule(nn.Module): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object - ) -> MultiModalEmbeddings | None: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index cce22842d3330..f3675075a48f4 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -439,7 +439,7 @@ class KimiLinearModel(nn.Module): "num_attention_heads must be divisible by world_size" ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -454,7 +454,7 @@ class KimiLinearModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -504,8 +504,8 @@ class KimiLinearForCausalLM( self.config.vocab_size, scale=logit_scale ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index fa04f60b9c140..8167b82f32330 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -404,7 +404,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> NestedTensors | None: + def embed_multimodal(self, **kwargs: object) -> NestedTensors | None: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 21d71887178e7..aeb25602f11a4 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -351,7 +351,7 @@ class Lfm2Model(nn.Module): else: self.embedding_norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -365,7 +365,7 
@@ class Lfm2Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -504,8 +504,8 @@ class Lfm2ForCausalLM( self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index b191164671050..6b7b5564ee989 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -466,7 +466,7 @@ class Lfm2MoeModel(nn.Module): else: self.embedding_norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -480,7 +480,7 @@ class Lfm2MoeModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -714,8 +714,8 @@ class Lfm2MoeForCausalLM( self.num_routed_experts = example_layer.n_routed_experts self.num_redundant_experts = example_layer.n_redundant_experts - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def update_physical_experts_metadata( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0a08bd376badc..c49a1ea817f91 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -424,7 +424,7 @@ class LlamaModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -438,7 +438,7 @@ class LlamaModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -640,8 +640,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ): return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index b59176191e7aa..e8716d652415e 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -82,7 +82,7 @@ class LlamaModel(nn.Module): ) self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def 
embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -93,7 +93,7 @@ class LlamaModel(nn.Module): inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.embed_input_ids(input_ids) hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: @@ -195,7 +195,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): def get_language_model(self) -> torch.nn.Module: return self.model - get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore + embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore def forward( self, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 3617294bd621d..ab2a9f6f06dbe 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -84,7 +84,7 @@ class LlamaModel(nn.Module): self.config.hidden_size * 2, self.config.hidden_size, bias=False ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -158,8 +158,8 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): self.config.vocab_size, scale=logit_scale ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index b8b9cc76d08d2..6edc9519dfbbf 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -172,7 +172,7 @@ class LlamaModel(nn.Module): eps=self.config.rms_norm_eps, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -183,7 +183,7 @@ class LlamaModel(nn.Module): input_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if input_embeds is None: - input_embeds = self.get_input_embeddings(input_ids) + input_embeds = self.embed_input_ids(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None @@ -261,13 +261,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): requires_grad=False, ) - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: NestedTensors | None = None, is_multimodal: torch.Tensor | None = None, ) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index a3dea0ce86f8e..c1fb2d4f4af7d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -661,7 +661,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = 
self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 3cf546644d04a..98b1b46045c3d 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -483,14 +483,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -501,9 +501,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ) -> torch.Tensor: # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 77c331b0182bd..902c598c226f0 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -422,7 +422,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: video_input = self._parse_and_validate_video_input(**kwargs) if video_input is None: return [] diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c4cae240ea469..322bde94ff66d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -866,7 +866,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index b848ae6e822f1..5de10e7086830 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -498,7 +498,7 @@ class FlashModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -512,7 +512,7 @@ class FlashModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + 
hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -583,8 +583,8 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 02abe693e071d..aa16640a94276 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -135,7 +135,7 @@ class MambaModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) def forward( @@ -149,7 +149,7 @@ class MambaModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -218,8 +218,8 @@ class MambaForCausalLM( self.backbone.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.backbone.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index d19480b064e05..fc17f98be1986 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -131,7 +131,7 @@ class Mamba2Model(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) def forward( @@ -145,7 +145,7 @@ class Mamba2Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -257,8 +257,8 @@ class Mamba2ForCausalLM( self.backbone.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.backbone.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 322cce79d4cb2..a84c99059cd9c 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -791,7 +791,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.decoder - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 
666ac90c44293..cd0a6190e9502 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -70,7 +70,7 @@ class MiMoModel(Qwen2Model): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 3d7695a2a3042..9905f65b74ca7 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -120,7 +120,7 @@ class MiMoMultiTokenPredictor(nn.Module): self.logits_processor = LogitsProcessor(config.vocab_size) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -164,8 +164,8 @@ class MiMoMTP(nn.Module): prefix=maybe_prefix(prefix, "lm_head"), ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index d9f0b477180e4..914b097fe199e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -440,7 +440,7 @@ class MiniCPMModel(nn.Module): prefix=f"{prefix}.layers", ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) return embedding * self.config.scale_emb @@ -455,7 +455,7 @@ class MiniCPMModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -615,8 +615,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 6efc61e25ea1b..0ca31913485db 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -193,7 +193,7 @@ class EagleMiniCPMModel(nn.Module): ] ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) return embedding * self.config.scale_emb @@ -203,7 +203,7 @@ class EagleMiniCPMModel(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor | IntermediateTensors: - input_embeds = self.get_input_embeddings(input_ids) + input_embeds = self.embed_input_ids(input_ids) input_embeds = self.input_norm1(input_embeds) hidden_states = 
self.input_norm2(hidden_states) @@ -354,8 +354,8 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): vllm_config=vllm_config, prefix=prefix, start_layer=start_layer ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 09937706f8c5d..2ac97764dd341 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1139,7 +1139,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.llm - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 21ed428a05d0f..49d2f2d261969 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -360,7 +360,7 @@ class MiniMaxM2Model(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -374,7 +374,7 @@ class MiniMaxM2Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -510,8 +510,8 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 1409a309f3aeb..bf1ecc822756d 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -620,7 +620,7 @@ class MiniMaxText01Model(nn.Module): ) minimax_cache_tensors[:, slots_tensor, ...] 
= 0 - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -709,8 +709,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index fb7c6d42a0658..0939a72ba53ec 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -353,7 +353,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support raise AssertionError("This line should be unreachable.") - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -371,8 +371,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings( + vision_embeddings = self.embed_multimodal(**kwargs) + inputs_embeds = self.embed_input_ids( input_ids, vision_embeddings, is_multimodal=input_ids == self.config.image_token_index, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 26d4deca2e120..1ddb470a0f93d 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -549,7 +549,7 @@ class Mistral3ForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c1f411b6cd2ac..d7a1cb82fb4fb 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -345,7 +345,7 @@ class MixtralModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -359,7 +359,7 @@ class MixtralModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -591,8 +591,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def 
embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 4548abde77d5f..14e741f322582 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -865,7 +865,7 @@ class Llama4ForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 5a0769f3bdaae..3a8a6c74d9d15 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -46,7 +46,7 @@ class ModernBertEmbeddings(nn.Module): ) self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) def forward( @@ -225,8 +225,8 @@ class ModernBertModel(nn.Module): config.hidden_size, eps=config.norm_eps, bias=config.norm_bias ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embeddings.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.embed_input_ids(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) @@ -337,8 +337,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): } ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] @@ -424,8 +424,8 @@ class ModernBertForTokenClassification(nn.Module): } ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self, skip_prefixes=["drop"]) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 7a9e3d81b73a1..ab83a271e30a0 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -832,7 +832,7 @@ class MolmoModel(nn.Module, SupportsQuant): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1491,7 +1491,7 @@ class MolmoForCausalLM( def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff 
--git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 29e887c4d9c98..106ad971a321a 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -248,7 +248,7 @@ class MPTModel(nn.Module): ["hidden_states"], config.d_model ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) def forward( @@ -262,7 +262,7 @@ class MPTModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -308,8 +308,8 @@ class MPTForCausalLM(nn.Module, SupportsPP): self.transformer.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 86fc1d6046cee..cb39c2ae482d2 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -655,7 +655,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): The replacement returned is not actually used to replace the placeholder tokens - it's just used to make sure we allocate the correct number of tokens. - Actual replacement is done in get_multimodal_embeddings of + Actual replacement is done in embed_multimodal of NemotronH_Nano_VL_V2 (specifically in _process_video_input -> _create_final_video_embeddings). 
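[Note] The docstring above describes the merge performed in `_create_final_video_embeddings`: text embeddings are kept for indicator tokens while video embeddings fill the remaining placeholder slots. A minimal sketch of that scatter-style merge follows; the function name, tensor names, and shapes are assumptions for illustration, not the actual vLLM helpers.

```python
import torch

def merge_indicator_embeddings(
    text_embeddings: torch.Tensor,   # (num_tokens, hidden), e.g. from embed_input_ids
    video_embeddings: torch.Tensor,  # (num_video_tokens, hidden), from the vision tower
    is_video: torch.Tensor,          # (num_tokens,) bool mask; True count must equal
) -> torch.Tensor:                   # the number of video embedding rows
    # Keep text embeddings everywhere else; scatter video embeddings
    # into the masked placeholder positions.
    final = text_embeddings.clone()
    final[is_video] = video_embeddings.to(dtype=final.dtype)
    return final

# Toy usage: 5 tokens, 2 of which are video placeholders.
text = torch.zeros(5, 4)
vid = torch.ones(2, 4)
mask = torch.tensor([False, False, True, True, False])
merged = merge_indicator_embeddings(text, vid, mask)
```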
There, we create the final embeddings with text embeddings for indicator tokens @@ -1401,7 +1401,7 @@ class NemotronH_Nano_VL_V2( # Create final video embeddings, merging text embeddings for indicator # tokens with video embeddings - text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids) + text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids) final_video_embeddings = _merge_multimodal_embeddings( inputs_embeds=text_embeddings, multimodal_embeddings=video_embeddings, @@ -1465,7 +1465,7 @@ class NemotronH_Nano_VL_V2( return modalities - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: # Validate the multimodal input keyword arguments modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if modalities is None: diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 17e8e7f28258d..92dcf5ea57008 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -351,7 +351,7 @@ class NemotronModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -365,7 +365,7 @@ class NemotronModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -491,8 +491,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8ef3eee173eb2..f7e0caf410e10 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -548,7 +548,7 @@ class NemotronHModel(nn.Module): self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -562,7 +562,7 @@ class NemotronHModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -823,8 +823,8 @@ class NemotronHForCausalLM( moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index acd0d0c982348..b839206a3094d 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py 
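[Note] The hunks above, and most that follow, apply the same mechanical two-level rename: the inner `*Model` embeds token ids directly from its embedding table, and the outer `*ForCausalLM` delegates to it. The new name also avoids clashing with Transformers' `get_input_embeddings()`, which returns the embedding *module* rather than a tensor, as visible later in this patch where `transformers/base.py` calls `self.model.get_input_embeddings()(input_ids)`. A toy sketch of the renamed protocol; the class names and sizes here are hypothetical.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Inner model: owns the token embedding table."""
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

class ToyForCausalLM(nn.Module):
    """Outer model: delegates, mirroring the renamed protocol method."""
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.model = ToyModel(vocab_size, hidden_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

embeds = ToyForCausalLM().embed_input_ids(torch.arange(4))  # shape (4, 8)
```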
@@ -291,7 +291,7 @@ class DeciModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -305,7 +305,7 @@ class DeciModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -461,8 +461,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return DeciModel(vllm_config=vllm_config, prefix=prefix) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 2f78e2f60c93b..5a1dda8aac2c1 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -561,7 +561,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -580,7 +580,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor return multimodal_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -593,9 +593,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index cb47f76a27ff5..487e3f671a455 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -268,7 +268,7 @@ class OlmoModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -285,7 +285,7 @@ class OlmoModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -379,8 +379,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: 
torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 2aa01adebc9f1..045582c889ee4 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -304,7 +304,7 @@ class Olmo2Model(nn.Module): ["hidden_states"], self.config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -419,8 +419,8 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 35a09334a1293..499eb05de76e4 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -296,7 +296,7 @@ class OlmoeModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -310,7 +310,7 @@ class OlmoeModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -471,8 +471,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index bf1b7570a8828..d13a745beffeb 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -753,7 +753,7 @@ class OpenPanguModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -767,7 +767,7 @@ class OpenPanguModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -969,8 +969,8 @@ class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py index f4049f2d39705..436b7f981b1f9 100644 --- 
a/vllm/model_executor/models/openpangu_mtp.py +++ b/vllm/model_executor/models/openpangu_mtp.py @@ -100,8 +100,8 @@ class OpenPanguMTP(nn.Module, SupportsPP): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d124b7671b9cf..5df700d1a2e17 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -262,7 +262,7 @@ class OPTDecoder(nn.Module): prefix=f"{prefix}.layers", ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -274,7 +274,7 @@ class OPTDecoder(nn.Module): ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.embed_input_ids(input_ids) pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) @@ -311,8 +311,8 @@ class OPTModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.decoder.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.decoder.embed_input_ids(input_ids) def forward( self, @@ -394,8 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index cbfce18b43885..859cd2cecf897 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -255,7 +255,7 @@ class OrionModel(nn.Module): config.hidden_size, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -269,7 +269,7 @@ class OrionModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -345,8 +345,8 @@ class OrionForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index cc7947df50aea..9db6c317c26a8 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -361,7 +361,7 @@ class OuroModel(nn.Module): self.total_ut_steps = 
getattr(self.config, "total_ut_steps", 4) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -374,7 +374,7 @@ class OuroModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) for current_ut in range(self.total_ut_steps): residual = None @@ -486,8 +486,8 @@ class OuroForCausalLM(nn.Module, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index cc6c9b4e72d76..a0fab820720fb 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -514,7 +514,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): return tuple(vision_embeddings) - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 9a4d69dea0968..85f37cfea10b1 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -617,7 +617,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): return modalities - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 62994abe8e317..183f458658aa3 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -1328,10 +1328,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support inputs_embeds = None elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) + vision_embeddings = self.embed_multimodal(**kwargs) is_multimodal = kwargs.pop("is_multimodal", None) handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False) - inputs_embeds = self.get_input_embeddings( + inputs_embeds = self.embed_input_ids( input_ids, vision_embeddings, is_multimodal=is_multimodal, @@ -1391,7 +1391,7 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support image_embeds = self.mlp_AR(vision_outputs, image_grid_thw) return image_embeds - def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return () diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index fb0b4b2904675..ec5d0fa6226dd 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -375,7 +375,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP def 
get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 37a7108d5c013..3bf6a1d9763d0 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -270,7 +270,7 @@ class PersimmonModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -284,7 +284,7 @@ class PersimmonModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -347,8 +347,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index e76fb1904727c..8fee53c23fb4b 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -240,7 +240,7 @@ class PhiModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -254,7 +254,7 @@ class PhiModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -346,8 +346,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a7b28bd18cc7a..384572217bc19 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -664,14 +664,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, 
multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -679,7 +679,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self._get_text_embeddings( + inputs_embeds = self._embed_text_input_ids( input_ids, self.embed_tokens, is_multimodal=is_multimodal, @@ -691,7 +691,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) if is_multimodal is None: raise ValueError( - "`get_input_embeddings` now requires `is_multimodal` arg, " + "`embed_input_ids` now requires `is_multimodal` arg, " "please update your model runner according to " "https://github.com/vllm-project/vllm/pull/16229." ) diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 4799b7aba7f76..0f1230a55bae6 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -1371,7 +1371,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ) return image_embeds - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index c2a3be16b6107..8425549a7bd20 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1180,7 +1180,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ) return image_embeds - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 97e5537877908..92fd858b608bc 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -482,7 +482,7 @@ class PhiMoEModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -496,7 +496,7 @@ class PhiMoEModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -648,8 +648,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index dfe5f0c52a505..8cb7d6a889da4 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -461,7 +461,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) def get_language_model(self) -> torch.nn.Module: return 
self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index ece1c5ec23cff..0c87f5000ff45 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -762,7 +762,7 @@ class Plamo2Model(torch.nn.Module): self.layers = Plamo2Decoder(vllm_config=vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -776,7 +776,7 @@ class Plamo2Model(torch.nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -839,8 +839,8 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index c99f628004fbd..50a125c3f5973 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -221,7 +221,7 @@ class QWenModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) def forward( @@ -235,7 +235,7 @@ class QWenModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index cdf32c6c51373..1bbb969ce5aa3 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -355,7 +355,7 @@ class Qwen2Model(nn.Module): self.aux_hidden_state_layers = tuple[int, ...]() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -369,7 +369,7 @@ class Qwen2Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -504,8 +504,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def 
set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 8f74cab0534da..262ea771d9cdf 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -1132,7 +1132,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( return llm_positions, mrope_position_delta - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1158,7 +1158,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( # TODO (ywang96): support overlapping modality embeddings so that # `use_audio_in_video` will work on V1. - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -1168,16 +1168,16 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ) -> torch.Tensor: # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, handle_oov_mm_token=handle_oov_mm_token, ) - def get_multimodal_embeddings_v0(self, **kwargs: object) -> NestedTensors | None: + def embed_multimodal_v0(self, **kwargs: object) -> NestedTensors | None: audio_input = self._parse_and_validate_audio_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b0557d58d6ddd..23591480b160e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1534,7 +1534,7 @@ class Qwen2_5_VLForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 4de6a19c1ff0c..7e883a393aa8d 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -439,7 +439,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index c03bd6a3c6d74..2ff0d19df238c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -389,7 +389,7 @@ class Qwen2MoeModel(nn.Module): ["hidden_states", "residual"], 
config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -403,7 +403,7 @@ class Qwen2MoeModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -566,8 +566,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index c5582218b852a..eac46e0f8b055 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -73,8 +73,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ff04baee91d1e..13b54bbe17488 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1507,7 +1507,7 @@ class Qwen2VLForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index f689ff79d7617..8d7f22a33fe6c 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -306,8 +306,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index d57b82cb02273..96751fee800bb 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -427,7 +427,7 @@ class Qwen3MoeModel(nn.Module): # Track layers for auxiliary hidden state outputs (EAGLE3) self.aux_hidden_state_layers: tuple[int, ...] 
= () - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -441,7 +441,7 @@ class Qwen3MoeModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -714,8 +714,8 @@ class Qwen3MoeForCausalLM( num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 9cd342caacb06..86508a7c64317 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -998,7 +998,7 @@ class Qwen3NextModel(nn.Module): else: self.norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -1012,7 +1012,7 @@ class Qwen3NextModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -1217,8 +1217,8 @@ class Qwen3NextForCausalLM( # Set MoE hyperparameters self.set_moe_parameters() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 9a552db029ee9..83694caa52480 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -93,7 +93,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): config.hidden_size, eps=config.rms_norm_eps ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -107,7 +107,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.embed_input_ids(input_ids) assert hidden_states.shape[-1] == inputs_embeds.shape[-1] inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) hidden_states = self.pre_fc_norm_hidden(hidden_states) @@ -257,8 +257,8 @@ class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts): ) self.set_moe_parameters() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index e6cb4442e2bef..5df2372a842cf 
100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -613,7 +613,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -1252,9 +1252,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object - ) -> MultiModalEmbeddings | None: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1278,7 +1276,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -1286,9 +1284,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self._get_text_embeddings( + inputs_embeds = self._embed_text_input_ids( input_ids, - self.language_model.get_input_embeddings, + self.language_model.embed_input_ids, is_multimodal=is_multimodal, handle_oov_mm_token=handle_oov_mm_token, ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 87494c6735cd1..5f5bde1dd72d3 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1100,7 +1100,7 @@ class Qwen3LLMModel(Qwen3Model): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -1493,9 +1493,7 @@ class Qwen3VLForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object - ) -> MultiModalEmbeddings | None: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1557,7 +1555,7 @@ class Qwen3VLForConditionalGeneration( return deepstack_input_embeds, multimodal_embeddings - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -1565,9 +1563,9 @@ class Qwen3VLForConditionalGeneration( is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self._get_text_embeddings( + inputs_embeds = self._embed_text_input_ids( input_ids, - self.language_model.get_input_embeddings, + self.language_model.embed_input_ids, is_multimodal=is_multimodal, handle_oov_mm_token=handle_oov_mm_token, ) @@ -1577,7 +1575,7 @@ class Qwen3VLForConditionalGeneration( if is_multimodal is None: raise ValueError( - "`get_input_embeddings` now requires `is_multimodal` arg, " + "`embed_input_ids` now requires `is_multimodal` arg, " "please update your model runner according to " "https://github.com/vllm-project/vllm/pull/16229." 
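[Note] The `ValueError` raised here (and the identical one in `phi3v.py` earlier in this patch) enforces that callers of the renamed `embed_input_ids` pass an `is_multimodal` mask whenever multimodal embeddings are supplied. A self-contained sketch of the calling convention; the class, the stand-in vision encoder, and the placeholder token id are made up for illustration.

```python
import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 7  # assumed placeholder id, for illustration only

class ToyMultiModal(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)

    def embed_multimodal(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Stand-in for a vision encoder: one embedding row per image token.
        return pixel_values.mean(dim=-1, keepdim=True).repeat(1, 4)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: torch.Tensor | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds
        if is_multimodal is None:
            raise ValueError("`embed_input_ids` now requires `is_multimodal`")
        # Scatter multimodal embeddings into the placeholder positions.
        inputs_embeds[is_multimodal] = multimodal_embeddings
        return inputs_embeds

model = ToyMultiModal()
input_ids = torch.tensor([1, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 3])
mm = model.embed_multimodal(torch.randn(2, 3))
out = model.embed_input_ids(input_ids, mm, is_multimodal=input_ids == IMAGE_TOKEN_ID)
```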
) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index 284b1301d07fa..5c3205faf9c2f 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -97,7 +97,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index cf74f72fe633d..6a259cade9cf1 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -777,7 +777,7 @@ class QwenVLForConditionalGeneration( def get_language_model(self) -> torch.nn.Module: return self.transformer - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index cfccb904f46c9..31cc645099141 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -220,8 +220,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.roberta.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.roberta.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 04da19a440a16..bf211d28f1844 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -334,7 +334,7 @@ class SeedOssModel(nn.Module): else: self.norm = PPMissingLayer() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -348,7 +348,7 @@ class SeedOssModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -467,8 +467,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 3cbdd64acc4a9..b175dd60cf650 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -595,7 +595,7 @@ class SiglipTextTransformer(nn.Module): self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, config.projection_size) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: 
return self.embeddings.token_embedding(input_ids) def forward( @@ -1117,7 +1117,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): def get_language_model(self) -> torch.nn.Module: return self.text_model - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -1130,16 +1130,16 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ) if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, handle_oov_mm_token=handle_oov_mm_token, ) - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 44550ae595d13..d825eb3a1c134 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -872,14 +872,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( + def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, @@ -892,9 +892,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: - return super().get_input_embeddings(input_ids) + return super().embed_input_ids(input_ids) - return super().get_input_embeddings( + return super().embed_input_ids( input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 5b8bf150edf6d..4ec855f794446 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -310,7 +310,7 @@ class SolarModel(nn.Module): ["hidden_states", "residual"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -324,7 +324,7 @@ class SolarModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -478,8 +478,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git 
a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index a4e309e0aa6ba..06eb7201c1a89 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -260,7 +260,7 @@ class StableLMEpochModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -332,8 +332,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 4cdc90b1f5cb9..0f2942acd5006 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -249,7 +249,7 @@ class Starcoder2Model(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -263,7 +263,7 @@ class Starcoder2Model(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -333,8 +333,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) def forward( self, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 381b3f4932e55..4fff356b29e28 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -354,7 +354,7 @@ class Step3TextModel(nn.Module): ["hidden_states"], config.hidden_size ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -368,7 +368,7 @@ class Step3TextModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None @@ -419,8 +419,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.model.make_empty_intermediate_tensors ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def embed_input_ids(self, input_ids: torch.Tensor) -> 

     def forward(
         self,
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index dbb549ba3f985..5d16be1eb3128 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -1075,14 +1075,14 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -1093,9 +1093,9 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     ) -> torch.Tensor:
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)

-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
@@ -1113,8 +1113,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         if intermediate_tensors is not None:
             inputs_embeds = None
         elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
+            vision_embeddings = self.embed_multimodal(**kwargs)
+            inputs_embeds = self.embed_input_ids(
                 input_ids,
                 vision_embeddings,
                 is_multimodal=input_ids == self.config.image_token_id,
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index bfa1b5bbaf84f..4d310712f303e 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -576,7 +576,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -593,8 +593,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         if intermediate_tensors is not None:
             inputs_embeds = None
         elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
+            vision_embeddings = self.embed_multimodal(**kwargs)
+            inputs_embeds = self.embed_input_ids(
                 input_ids,
                 vision_embeddings,
                 is_multimodal=input_ids == self.config.image_token_index,
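The multimodal models above layer two methods on top of the text pattern: embed_multimodal runs the encoder over image or audio inputs, and embed_input_ids merges those features into the token embeddings at the positions flagged by is_multimodal (here, where input_ids equals the image token). A rough sketch of the merge step, assuming a plain boolean-mask scatter rather than vLLM's actual helper:

import torch


def merge_by_mask(
    token_embeds: torch.Tensor,   # (seq_len, hidden) from the token table
    mm_embeds: torch.Tensor,      # (num_mm_tokens, hidden) from the encoder
    is_multimodal: torch.Tensor,  # (seq_len,) bool mask of placeholder slots
) -> torch.Tensor:
    # Overwrite placeholder positions with multimodal features; the number
    # of True entries in the mask must equal the number of encoder rows.
    merged = token_embeds.clone()
    merged[is_multimodal] = mm_embeds
    return merged


tok = torch.zeros(5, 4)
mm = torch.ones(2, 4)
mask = torch.tensor([False, True, True, False, False])
assert merge_by_mask(tok, mm, mask)[1].sum() == 4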
diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py
index 4dfeddb0b28e4..8a0bec9dff848 100644
--- a/vllm/model_executor/models/teleflm.py
+++ b/vllm/model_executor/models/teleflm.py
@@ -57,7 +57,7 @@ class TeleFLMModel(LlamaModel):
         if self.use_mup:
             self.input_mult = self.config.input_mult

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         embedding = self.embed_tokens(input_ids)
         if self.use_mup:
             embedding = embedding * self.input_mult
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index e799e41e2c387..19052c8d49e44 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -251,7 +251,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):

         self.pooler = DispatchPooler({"plugin": DummyPooler()})

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index eb992f7bec72b..63096e57f8eee 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -385,7 +385,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):

         _init_parameters(module, dtype)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
         if self.embed_scale is not None:
             inputs_embeds *= self.embed_scale
@@ -416,7 +416,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
             and input_ids is not None
             and inputs_embeds is None
         ):
-            inputs_embeds = self.get_input_embeddings(input_ids)
+            inputs_embeds = self.embed_input_ids(input_ids)
             input_ids = None

         if self.model_config.uses_mrope:
diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index 2efcef68d1c72..9b0463f41fa87 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -330,7 +330,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):

         return LanguageModel(self)

-    def get_multimodal_embeddings(self, **kwargs):
+    def embed_multimodal(self, **kwargs):
         pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
         image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
         # Model might use `image_patches` instead of `pixel_values`
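Note that the Transformers backend above still calls the upstream Hugging Face accessor model.get_input_embeddings() internally; only vLLM's protocol method is renamed, which is precisely the name clash this patch resolves. A toy version of such an embedding wrapper with a µP-style input multiplier like TeleFLM's (class name and sizes invented for illustration):

import torch
from torch import nn


class ScaledEmbedModel(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 4,
                 input_mult: float = 2.0) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.input_mult = input_mult

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # vLLM protocol method: look up, then apply the input multiplier.
        return self.embed_tokens(input_ids) * self.input_mult


m = ScaledEmbedModel()
ids = torch.tensor([3, 7])
# Scaling is uniform, so the output is input_mult times the raw table rows.
assert torch.allclose(m.embed_input_ids(ids), m.embed_tokens(ids) * 2.0)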
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 95d574fb81d7a..bb0f6bd036f14 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -579,14 +579,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return []
         audio_embeddings = self._process_audio_input(audio_input)
         return audio_embeddings

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -597,9 +597,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     ) -> torch.Tensor:
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)

-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index e5ebd8138b0ac..f14b79f2886c4 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -474,7 +474,7 @@ def _merge_multimodal_embeddings(

 @deprecated(
     "`merge_multimodal_embeddings` has been replaced with "
-    "`SupportsMultiModal.get_input_embeddings` and will be "
+    "`SupportsMultiModal.embed_input_ids` and will be "
     "removed in v0.12."
 )
 def merge_multimodal_embeddings(
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py
index cce18984b67e4..18ad8851fccda 100644
--- a/vllm/model_executor/models/voxtral.py
+++ b/vllm/model_executor/models/voxtral.py
@@ -399,7 +399,7 @@ class VoxtralForConditionalGeneration(

         return hidden_states

-    def get_multimodal_embeddings(
+    def embed_multimodal(
         self, **kwargs
     ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
         audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 23436a27d489d..91a10b95a08c0 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -570,7 +570,7 @@ class WhisperDecoder(nn.Module):
         positions: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None,
     ):
-        inputs_embeds = self.get_input_embeddings(input_ids)
+        inputs_embeds = self.embed_input_ids(input_ids)
         positions = self.embed_positions(positions)
         hidden_states = inputs_embeds + positions
@@ -583,7 +583,7 @@ class WhisperDecoder(nn.Module):
         hidden_states = self.layer_norm(hidden_states)
         return hidden_states

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -907,12 +907,12 @@ class WhisperForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.model.decoder

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         return [self.model.get_encoder_outputs(audio_input["input_features"])]

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -922,7 +922,7 @@ class WhisperForConditionalGeneration(
     ) -> torch.Tensor:
         # This method just returns the decoder sequence embeddings since
         # Whisper does not have encoder text tokens.
-        return self.model.decoder.get_input_embeddings(input_ids)
+        return self.model.decoder.embed_input_ids(input_ids)

     def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
         input_features = kwargs.pop("input_features", None)
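The utils.py hunk above keeps the old free function alive behind a deprecation shim, so downstream callers get a warning instead of a hard break until v0.12. A self-contained sketch of that mechanism, assuming the @deprecated decorator here is the PEP 702 one from typing_extensions (vLLM's actual import may differ) and a stub body in place of the real helper:

import warnings

from typing_extensions import deprecated  # PEP 702 backport


@deprecated(
    "`merge_multimodal_embeddings` has been replaced with "
    "`SupportsMultiModal.embed_input_ids` and will be removed in v0.12."
)
def merge_multimodal_embeddings() -> None:
    """Stub body; the real helper merges encoder outputs into embeddings."""


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    merge_multimodal_embeddings()  # emits DeprecationWarning at the call site

assert any("embed_input_ids" in str(w.message) for w in caught)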
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index bf3107525bc53..64e6979c8fcfb 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -756,7 +756,7 @@ class Zamba2Model(nn.Module):
         # Final layer normalization
         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         """Convert input token IDs to embeddings.

         Args:
@@ -786,7 +786,7 @@ class Zamba2Model(nn.Module):
         """
         # Handle pipeline parallelism for first rank
         if inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids)
+            inputs_embeds = self.embed_input_ids(input_ids)
         hidden_states = inputs_embeds

         # Process through layers
@@ -930,14 +930,14 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
         # Initialize logits processing and sampling
         self.logits_processor = LogitsProcessor(config.vocab_size)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         """Convert input token IDs to embeddings.

         Args:
             input_ids: Tensor of input token IDs

         Returns:
             Embedded representation of the input tokens
         """
-        return self.model.get_input_embeddings(input_ids)
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 55132a6036efb..85a03efd5bb9b 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -208,7 +208,7 @@ class PromptUpdateDetails(Generic[_S]):
     `None` (default) means to assign embeddings to all positions of `full`.

     The embeddings are obtained by calling
-    [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
+    [`SupportsMultiModal.embed_multimodal`][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
     """

     @staticmethod
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 55b04949ceb2a..beef5203e0394 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -279,7 +279,7 @@ class EagleProposer:
         if self.supports_mm_inputs:
             mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

-            self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
+            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                 self.input_ids[:num_tokens],
                 multimodal_embeddings=mm_embeds,
                 is_multimodal=is_mm_embed,
@@ -447,9 +447,7 @@ class EagleProposer:
             self._set_positions(batch_size, clamped_positions)
             self.hidden_states[:batch_size] = hidden_states
             if self.supports_mm_inputs:
-                self.inputs_embeds[:batch_size] = self.model.get_input_embeddings(
-                    input_ids
-                )
+                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:input_batch_size]
@@ -972,9 +970,7 @@ class EagleProposer:
         # text-only draft models
         try:
             dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
-            self.model.get_input_embeddings(
-                dummy_input_ids, multimodal_embeddings=None
-            )
+            self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
         except (NotImplementedError, AttributeError, TypeError):
             logger.warning(
                 "Draft model does not support multimodal inputs, "
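The last eagle.py hunk above is a duck-typing probe: the proposer calls embed_input_ids once with the multimodal keyword, and a TypeError (or missing attribute) marks the draft model as text-only. Stripped to its essentials, the check looks roughly like this (class names invented for the sketch):

import torch
from torch import nn


class TextOnlyDraft(nn.Module):
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return torch.zeros(input_ids.shape + (4,))


class MMDraft(nn.Module):
    def embed_input_ids(self, input_ids, multimodal_embeddings=None):
        return torch.zeros(input_ids.shape + (4,))


def supports_mm_inputs(model: nn.Module) -> bool:
    # Probe with the multimodal keyword; text-only signatures reject it.
    try:
        model.embed_input_ids(torch.tensor([[1]]), multimodal_embeddings=None)
        return True
    except (NotImplementedError, AttributeError, TypeError):
        return False


assert supports_mm_inputs(MMDraft()) is True
assert supports_mm_inputs(TextOnlyDraft()) is False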
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 987d451fd6baf..c9c64137ca04b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1853,7 +1853,7 @@ class GPUModelRunner(
                         )
                     )

-                    micro_batch_outputs = model.get_multimodal_embeddings(
+                    micro_batch_outputs = model.embed_multimodal(
                         **micro_batch_mm_inputs
                     )

@@ -1866,7 +1866,7 @@ class GPUModelRunner(
            # 2. A list or tuple (length: num_items) of tensors,
            # each of shape (feature_size, hidden_size) in case the feature
            # size is dynamic depending on the input multimodal items.
-            curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
+            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -2225,7 +2225,7 @@ class GPUModelRunner(
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            inputs_embeds_scheduled = self.model.get_input_embeddings(
+            inputs_embeds_scheduled = self.model.embed_input_ids(
                 self.input_ids.gpu[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds,
                 is_multimodal=is_mm_embed,
@@ -2261,7 +2261,7 @@ class GPUModelRunner(
             # Some tokens ids may need to become embeds
             if token_ids_idx.numel() > 0:
                 token_ids = self.input_ids.gpu[token_ids_idx]
-                tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids)
+                tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids)
                 self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds

             inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
@@ -3889,7 +3889,7 @@ class GPUModelRunner(
             )

             # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+            dummy_encoder_outputs = self.model.embed_multimodal(
                 **batched_dummy_mm_inputs
             )
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 26816ce0f2091..0f90578671db5 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -962,7 +962,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # (feature_size, hidden_size) in case the feature size is dynamic
         # depending on the input multimodal items.
         torch_xla.sync(wait=False)
-        curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
+        curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
         torch_xla.sync(wait=False)

         sanity_check_mm_encoder_outputs(
@@ -1065,7 +1065,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            inputs_embeds = self.model.get_input_embeddings(
+            inputs_embeds = self.model.embed_input_ids(
                 input_ids,
                 multimodal_embeddings=mm_embeds,
                 is_multimodal=is_mm_embed,
@@ -1484,14 +1482,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
             # Run multimodal encoder.
             torch_xla.sync(wait=False)
-            mm_embeds = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs
-            )
+            mm_embeds = self.model.embed_multimodal(**batched_dummy_mm_inputs)
             torch_xla.sync(wait=False)

             num_patches = mm_embeds[0].shape[0]
             items_size = num_patches * num_items

-            # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
+            # NOTE (NickLucche) pre-compile `embed_input_ids` when mm
             # embeddings are present. We assume `--disable-mm-chunked`,
             # hence only whole items can be scheduled. This implies we just
             # need to compile when `num_items` fit the (padded) `input_ids`
@@ -1519,7 +1517,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 assert a is None
             torch_xla.sync(wait=False)

-        # Pre-compile `get_input_embeddings` when mm_embeddings are not
+        # Pre-compile `embed_input_ids` when mm_embeddings are not
         # present. Chunk is only made of text, no mm_placeholders.
        for num_tokens in self.num_tokens_paddings:
            placeholders_ids = torch.zeros(
@@ -1738,7 +1736,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            # impact of recompilation until it's fixed.
            start = time.perf_counter()
            torch_xla.sync(wait=False)
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+            dummy_encoder_outputs = self.model.embed_multimodal(
                **batched_dummy_mm_inputs
            )
            torch_xla.sync(wait=False)
@@ -1974,11 +1972,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        )
        return logits_cloned

-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
+    def embed_multimodal(self, *args, **kwargs):
+        return self.model.embed_multimodal(*args, **kwargs)

-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
+    def embed_input_ids(self, *args, **kwargs):
+        return self.model.embed_input_ids(*args, **kwargs)

     def prepare_structured_decoding_input(
         self, logits: torch.Tensor, grammar_output: "GrammarOutput"
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 0ca7e81a5c7b8..095407a8b9596 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -177,27 +177,27 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method."
+        "of the model's `embed_multimodal` method."
     )

     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method."
+        "of the model's `embed_multimodal` method."
     )

     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method."
+        "of the model's `embed_multimodal` method."
     )
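The three assertions above pin down the output contract for embed_multimodal: either a single 3D tensor or a sequence of 2D tensors, one entry per input item, where each item's feature size may differ. A tiny example of a return value that satisfies all three checks (shapes made up for illustration):

import torch

# Two image items with different (dynamic) feature sizes, fixed hidden size.
hidden_size = 8
mm_embeddings = [torch.zeros(5, hidden_size), torch.zeros(3, hidden_size)]
expected_num_items = 2

assert isinstance(mm_embeddings, (list, tuple, torch.Tensor))
assert len(mm_embeddings) == expected_num_items
assert all(e.ndim == 2 for e in mm_embeddings)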