diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index bb4177dfc457..b69c7b6a9376 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         multimodal_embeddings = self._process_image_input(image_input)
         return multimodal_embeddings
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 7e15e57a4d03..6a95ac089ff4 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input, **kwargs)
 
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 279541bed55a..87fc6b5b0240 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index aea44261dd69..21f29dc43c26 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         assert self.model.vqmodel is not None
         image_tokens = self.model.get_image_tokens(image_input["data"].to(
             self.config.torch_dtype))
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index d8c01f83eded..6341c65a5d4c 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 47760aabb959..4b220ea483e8 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index cb141dbc5aa3..462f85c3dd62 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 18cb6ea68d1a..b633c0003c63 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index 034c7654f4d9..e9271367a472 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
index 831164ba88a4..137aad926cb9 100644
--- a/vllm/model_executor/models/granite_speech.py
+++ b/vllm/model_executor/models/granite_speech.py
@@ -706,10 +706,10 @@ class GraniteSpeechForConditionalGeneration(
     def get_multimodal_embeddings(
         self,
         **kwargs: object,
-    ) -> Optional[MultiModalEmbeddings]:
+    ) -> MultiModalEmbeddings:
         """Compute the audio embeddings if audio inputs are present."""
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
         audio_features = self._process_audio_input(audio_input)
         return audio_features
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index de8596282ca9..be04ad0422df 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index cb2a4062b84c..0e7e4e73eca9 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol):
     MRO of your model class.
     """
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         """
         Returns multimodal embeddings generated from multimodal kwargs
         to be merged with text embeddings.
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 0c61369c5f51..9d5cceccff2f 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -1304,11 +1304,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 725e1b2c1948..7dea260a58e0 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 6f5f231875de..60ede454ff27 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
 
-        if multimodal_embeddings is None:
+        if not multimodal_embeddings:
             return self.language_model.get_input_embeddings(input_ids)
 
         inputs_embeds = embed_multimodal(
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index a3406d090db8..78084465e7a2 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         video_input = self._parse_and_validate_video_input(**kwargs)
         if video_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_video_pixels(video_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index d90d3d4a0960..265f63d7bd29 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -839,11 +839,11 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
        return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 4100fee0ec84..b923287dca3e 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.llm
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []
 
         return self._process_multimodal_inputs(modalities)
 
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index b2ededcaf67c..bc00af2ec6b9 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         raise AssertionError("This line should be unreachable.")
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 9147240b2b2a..59deacffd285 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index 54fae279d531..bf4bd309eea2 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -794,11 +794,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self,
-                                  **kwargs) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 1fa76b9ac7af..70c60c6d528b 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1473,11 +1473,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index 770e08aa2a5f..900a1f5de458 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -499,11 +499,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
 
         return tuple(vision_embeddings)
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         image_features = self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index a0e2912578c5..cc2cebe4a4a3 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -338,11 +338,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
         vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 376c53d2cb99..9cec7831ae0c 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -655,11 +655,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
@@ -669,7 +669,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
-        if multimodal_embeddings is not None:
+        if multimodal_embeddings:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 self.image_token_id)
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index 924e6436897d..a3ca72d1f5cf 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1112,11 +1112,11 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             image_attention_mask)
         return image_embeds
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 705586b6a6ea..320c0e10d06a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -409,11 +409,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 7172394e4200..ad1e8fcb39d5 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -772,13 +772,13 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 73d241921bcf..202cd5e860d1 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -1016,13 +1016,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 6951630c6f23..e77a8e05d200 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -350,11 +350,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
 
         masked_audio_features = self._process_audio_input(audio_input)
         return masked_audio_features
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index d8318fff868e..49b709069cd2 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1257,11 +1257,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []
 
         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index e828ce9c9849..546737621a7c 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -738,11 +738,11 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index 08c47facad97..9fba24ac5cec 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -869,11 +869,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
 
         return self._process_image_input(image_input)
 
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 5aa3ddabc19e..2645e700fcda 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -585,11 +585,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)
 
     def get_input_embeddings(
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 43836f2956c3..f6b9d19694ef 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -546,11 +546,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
 
         audio_embeddings = self._process_audio_input(audio_input)
         return audio_embeddings
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 3ee5f7dba01f..8cf2a009d667 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -687,8 +687,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
     def get_language_model(self) -> torch.nn.Module:
         return self.model.decoder
 
-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         # TODO: This method does not obey the interface for SupportsMultiModal.
         # Refactor this once encoder/decoder support is implemented in V1.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 055cf01530f0..70339ff2f005 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -4,11 +4,12 @@
 from typing import Optional
 
 import torch
 
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
 
 
 def sanity_check_mm_encoder_outputs(
-    mm_embeddings: object,
+    mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
 ) -> None:
     """