[Deprecation] Remove fallbacks for embed_input_ids and embed_multimodal (#30458)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-12-11 14:58:23 +08:00 committed by GitHub
parent 36c9ce2554
commit 979f50efd0
5 changed files with 9 additions and 68 deletions

View File

@@ -111,13 +111,7 @@ class SupportsMultiModal(Protocol):
         the appearances of their corresponding multimodal data item in the
         input prompt.
         """
-        if hasattr(self, "get_multimodal_embeddings"):
-            logger.warning_once(
-                "`get_multimodal_embeddings` for vLLM models is deprecated and will be "
-                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
-                "this method to `embed_multimodal`."
-            )
-            return self.get_multimodal_embeddings(**kwargs)
+        ...

     def get_language_model(self) -> VllmModel:
         """
@@ -196,12 +190,7 @@ class SupportsMultiModal(Protocol):
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
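
For model authors migrating off the removed fallbacks, here is a minimal sketch of the post-deprecation interface, assuming a toy `MyModel`; only the names `embed_multimodal`/`embed_input_ids`, the keyword-only `is_multimodal` argument, and the assert come from this change, everything else (the layers, the `pixel_values` kwarg) is hypothetical:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.vision_proj = nn.Linear(4, hidden)  # stand-in vision tower

    # Formerly `get_multimodal_embeddings`: one embedding tensor per
    # multimodal item, ordered as the items appear in the prompt.
    def embed_multimodal(self, **kwargs) -> tuple[torch.Tensor, ...]:
        pixel_values = kwargs["pixel_values"]  # hypothetical input name
        return tuple(self.vision_proj(pv) for pv in pixel_values)

    # Formerly `get_input_embeddings`; `is_multimodal` must now be supplied
    # whenever multimodal embeddings are present.
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds
        assert is_multimodal is not None
        flat = torch.cat(list(multimodal_embeddings)).to(inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(is_multimodal.unsqueeze(-1), flat)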

View File

@@ -68,15 +68,6 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
 def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
     model_embed_input_ids = getattr(model, "embed_input_ids", None)
     if not callable(model_embed_input_ids):
-        model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
-        if callable(model_get_input_embeddings):
-            logger.warning(
-                "`get_input_embeddings` for vLLM models is deprecated and will be "
-                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
-                "this method to `embed_input_ids`."
-            )
-            model.embed_input_ids = model_get_input_embeddings
-            return True
         logger.warning(
             "The model (%s) is missing the `embed_input_ids` method.",
             model,
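
With the fallback gone, the registry check reduces to a plain callable lookup: a model that still only defines the old `get_input_embeddings` name is reported as missing the hook instead of being silently patched. An illustrative reduction of the check, not the vLLM source:

def check_embed_input_ids(model: object) -> bool:
    # Only the new name counts; the alias to `get_input_embeddings` is gone.
    return callable(getattr(model, "embed_input_ids", None))

class OldStyle:
    def get_input_embeddings(self, input_ids):  # old name, no longer accepted
        raise NotImplementedError

class NewStyle:
    def embed_input_ids(self, input_ids):  # required name
        raise NotImplementedError

assert check_embed_input_ids(NewStyle())
assert not check_embed_input_ids(OldStyle())  # was True via the removed alias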

View File

@@ -18,15 +18,10 @@ from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
     DeepseekV2Model,
 )

-from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
-from vllm.multimodal.inputs import NestedTensors

-from .utils import (
-    _merge_multimodal_embeddings,
-    make_empty_intermediate_tensors_factory,
-    maybe_prefix,
-)
+from .interfaces import SupportsMultiModal
+from .utils import make_empty_intermediate_tensors_factory, maybe_prefix

 logger = init_logger(__name__)
@@ -117,26 +112,10 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
         )
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: MultiModalEmbeddings | None = None,
-        *,
-        is_multimodal: torch.Tensor | None = None,
-        handle_oov_mm_token: bool = False,
-    ) -> torch.Tensor:
-        inputs_embeds = super().embed_input_ids(input_ids)
-
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        assert is_multimodal is not None
-
-        return _merge_multimodal_embeddings(
-            inputs_embeds=inputs_embeds,
-            multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
-        )
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
+    embed_input_ids = SupportsMultiModal.embed_input_ids  # type: ignore

     def forward(
         self,
@@ -155,11 +134,3 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
         "model.embed_tokens.weight",
         "lm_head.weight",
     }
-
-    def embed_input_ids(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: NestedTensors | None = None,
-        is_multimodal: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        return self.model.embed_input_ids(input_ids)
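
The one-line `embed_input_ids = SupportsMultiModal.embed_input_ids` replaces both deleted overrides: in Python 3 a class function is a plain function, so assigning it from another class installs it as an ordinary method of the assignee. A generic illustration with hypothetical class names:

class Base:
    def describe(self) -> str:
        return f"method running on {type(self).__name__}"

class Reuser:
    # Borrow Base's function as our own method; no inheritance required.
    describe = Base.describe

print(Reuser().describe())  # -> "method running on Reuser"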

View File

@@ -687,12 +687,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
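
On the caller side, `is_multimodal` is a boolean mask over token positions and is now mandatory whenever multimodal embeddings are passed. A sketch of how a runner might build it by comparing input IDs against the model's placeholder token; the ID and variable names here are made up:

import torch

IMAGE_PLACEHOLDER_ID = 32000  # hypothetical placeholder token id

input_ids = torch.tensor([1, 5, IMAGE_PLACEHOLDER_ID, IMAGE_PLACEHOLDER_ID, 7])
is_multimodal = input_ids == IMAGE_PLACEHOLDER_ID  # [False, False, True, True, False]

# inputs_embeds = model.embed_input_ids(
#     input_ids,
#     multimodal_embeddings=mm_embeds,
#     is_multimodal=is_multimodal,
# )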

View File

@@ -1572,12 +1572,7 @@ class Qwen3VLForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         if self.use_deepstack:
             (
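
The failure mode for stale callers changes accordingly: omitting `is_multimodal` now trips a bare AssertionError rather than the removed ValueError that pointed at PR #16229. A self-contained sketch of the new behavior, with an illustrative merge standing in for `_merge_multimodal_embeddings`:

import torch

def embed_with_merge(inputs_embeds, multimodal_embeddings, is_multimodal=None):
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds
    assert is_multimodal is not None  # was: ValueError with a migration hint
    flat = torch.cat(list(multimodal_embeddings))
    return inputs_embeds.masked_scatter(is_multimodal.unsqueeze(-1), flat)

try:
    embed_with_merge(torch.zeros(3, 8), [torch.zeros(1, 8)])
except AssertionError:
    print("callers must now pass is_multimodal explicitly")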