From 3a3b06ee706e6ff99b711b20a6c431b43e490dbc Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Thu, 11 Dec 2025 22:39:51 +0800
Subject: [PATCH] [Misc] Improve error message for `is_multimodal` (#30483)

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/interfaces.py | 20 +++++++++++++++++---
 vllm/model_executor/models/phi3v.py      |  5 ++---
 vllm/model_executor/models/qwen3_vl.py   |  3 ++-
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 1e5d80dd2f313..cb99d57e8b8c7 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -53,6 +53,22 @@ The output embeddings must be one of the following formats:
 """
 
 
+def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
+    """
+    A helper function to be used in the context of
+    [vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][]
+    to provide a better error message.
+    """
+    if is_multimodal is None:
+        raise ValueError(
+            "`embed_input_ids` now requires `is_multimodal` arg, "
+            "please update your model runner according to "
+            "https://github.com/vllm-project/vllm/pull/16229."
+        )
+
+    return is_multimodal
+
+
 @runtime_checkable
 class SupportsMultiModal(Protocol):
     """The interface required for all multi-modal models."""
@@ -190,12 +206,10 @@ class SupportsMultiModal(Protocol):
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
-        assert is_multimodal is not None
-
         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
             multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
+            is_multimodal=_require_is_multimodal(is_multimodal),
         )
 
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 0d39e29dcc97b..900b0eade308c 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -64,6 +64,7 @@ from .interfaces import (
     SupportsMultiModal,
     SupportsPP,
     SupportsQuant,
+    _require_is_multimodal,
 )
 from .utils import (
     AutoWeightsLoader,
@@ -687,12 +688,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
-        assert is_multimodal is not None
-
         return _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
             multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
+            is_multimodal=_require_is_multimodal(is_multimodal),
         )
 
     def forward(
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index eac3774196a0a..f8e0ea6284994 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -93,6 +93,7 @@ from .interfaces import (
     SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
+    _require_is_multimodal,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
@@ -1572,7 +1573,7 @@ class Qwen3VLForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
-        assert is_multimodal is not None
+        is_multimodal = _require_is_multimodal(is_multimodal)
 
         if self.use_deepstack:
             (
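
For reference, a minimal sketch (not part of the patch) of how the new helper behaves, assuming the patched vllm/model_executor/models/interfaces.py is on the import path:

import torch

from vllm.model_executor.models.interfaces import _require_is_multimodal

# Boolean mask marking which flattened input positions hold multimodal
# placeholder tokens (the argument the model runner is expected to pass).
is_multimodal = torch.tensor([False, True, True, False])

# With a valid mask, the helper is a pass-through.
assert _require_is_multimodal(is_multimodal) is is_multimodal

# With None (e.g. an out-of-date model runner that never passes the mask),
# it now raises a descriptive ValueError pointing at PR #16229 instead of
# failing on a bare `assert is_multimodal is not None`.
try:
    _require_is_multimodal(None)
except ValueError as exc:
    print(exc)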