mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 00:55:01 +08:00)
[Deprecation] Remove fallbacks for embed_input_ids and embed_multimodal (#30458)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 36c9ce2554
commit 979f50efd0
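For out-of-tree models, the change this commit enforces is a pure rename of two interface methods; the deprecation warnings deleted below spell out the mapping. A minimal sketch of the rename, assuming a hypothetical custom model class (the class name and method bodies are illustrative, not code from this commit):

```python
import torch

class MyMultiModalModel:  # hypothetical out-of-tree model
    # Before: get_input_embeddings(...). vLLM no longer falls back to the
    # old name, so only the new name is recognized.
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        ...

    # Before: get_multimodal_embeddings(**kwargs). Same rename rule.
    def embed_multimodal(self, **kwargs):
        ...
```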
@@ -111,13 +111,7 @@ class SupportsMultiModal(Protocol):
         the appearances of their corresponding multimodal data item in the
         input prompt.
         """
-        if hasattr(self, "get_multimodal_embeddings"):
-            logger.warning_once(
-                "`get_multimodal_embeddings` for vLLM models is deprecated and will be "
-                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
-                "this method to `embed_multimodal`."
-            )
-            return self.get_multimodal_embeddings(**kwargs)
+        ...

     def get_language_model(self) -> VllmModel:
         """
@@ -196,12 +190,7 @@ class SupportsMultiModal(Protocol):
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
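The removed `ValueError` was itself a transition shim: `embed_input_ids` has required the `is_multimodal` mask since PR #16229, and after this commit a missing mask simply fails the `assert`. A self-contained sketch of the call contract (the toy model, placeholder id 9, and all dimensions are illustrative assumptions, not vLLM code):

```python
import torch

class ToyModel:
    """Toy stand-in that mimics the merge contract of `embed_input_ids`."""

    def __init__(self) -> None:
        self.emb = torch.randn(10, 4)  # toy token-embedding table

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: torch.Tensor | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        inputs_embeds = self.emb[input_ids]
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds
        assert is_multimodal is not None  # the mask is now mandatory
        inputs_embeds[is_multimodal] = multimodal_embeddings  # scatter mm rows
        return inputs_embeds

model = ToyModel()
input_ids = torch.tensor([1, 9, 9, 2])  # 9 = multimodal placeholder token
mm_embeds = torch.zeros(2, 4)           # one embedding row per placeholder
out = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=input_ids == 9)
```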
@@ -68,15 +68,6 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
 def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
     model_embed_input_ids = getattr(model, "embed_input_ids", None)
     if not callable(model_embed_input_ids):
-        model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
-        if callable(model_get_input_embeddings):
-            logger.warning(
-                "`get_input_embeddings` for vLLM models is deprecated and will be "
-                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
-                "this method to `embed_input_ids`."
-            )
-            model.embed_input_ids = model_get_input_embeddings
-            return True
         logger.warning(
             "The model (%s) is missing the `embed_input_ids` method.",
             model,
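With the aliasing fallback gone, the interface check reduces to a plain capability probe. Roughly, paraphrasing the surviving lines (not verbatim source; the `return` values are inferred from the function's `-> bool` signature, and the hunk above truncates before them):

```python
import logging

logger = logging.getLogger(__name__)

def _check_vllm_model_embed_input_ids(model) -> bool:
    # Paraphrase of the post-change logic: warn and fail if the new-style
    # method is absent; no more silent aliasing of `get_input_embeddings`.
    if not callable(getattr(model, "embed_input_ids", None)):
        logger.warning(
            "The model (%s) is missing the `embed_input_ids` method.", model
        )
        return False
    return True
```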
@@ -18,15 +18,10 @@ from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
     DeepseekV2Model,
 )
-from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
-from vllm.multimodal.inputs import NestedTensors

-from .utils import (
-    _merge_multimodal_embeddings,
-    make_empty_intermediate_tensors_factory,
-    maybe_prefix,
-)
+from .interfaces import SupportsMultiModal
+from .utils import make_empty_intermediate_tensors_factory, maybe_prefix

 logger = init_logger(__name__)

@@ -117,26 +112,10 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
         )
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: MultiModalEmbeddings | None = None,
-        *,
-        is_multimodal: torch.Tensor | None = None,
-        handle_oov_mm_token: bool = False,
-    ) -> torch.Tensor:
-        inputs_embeds = super().embed_input_ids(input_ids)
-
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        assert is_multimodal is not None
-
-        return _merge_multimodal_embeddings(
-            inputs_embeds=inputs_embeds,
-            multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
-        )
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
+    embed_input_ids = SupportsMultiModal.embed_input_ids  # type: ignore

     def forward(
         self,
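Note the design choice in the replacement: rather than redefining the merge logic, the Eagle head borrows the shared default by assigning `SupportsMultiModal.embed_input_ids` as a class attribute. A generic sketch of that borrowing pattern (`Greeter` and `Borrower` are illustrative names, not vLLM classes):

```python
class Greeter:
    # Shared default implementation.
    def greet(self) -> str:
        return f"hello from {type(self).__name__}"

class Borrower:
    # Reuse the function object as-is instead of redefining it;
    # `self` binds to the Borrower instance at call time.
    greet = Greeter.greet

print(Borrower().greet())  # -> hello from Borrower
```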
@@ -155,11 +134,3 @@ class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
             "model.embed_tokens.weight",
             "lm_head.weight",
         }
-
-    def embed_input_ids(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: NestedTensors | None = None,
-        is_multimodal: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        return self.model.embed_input_ids(input_ids)
@@ -687,12 +687,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
@@ -1572,12 +1572,7 @@ class Qwen3VLForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds

-        if is_multimodal is None:
-            raise ValueError(
-                "`embed_input_ids` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
+        assert is_multimodal is not None

         if self.use_deepstack:
             (