[V1] Change return type on get_multimodal_embeddings() (#19446)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Author: Russell Bryant
Date: 2025-06-16 13:32:15 -04:00 (committed by GitHub)
Commit: 90f9c2eb5c
Parent: 387bdf0ab9
37 changed files with 108 additions and 103 deletions
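
Every hunk below applies the same two-part change: the Optional[...] wrapper is dropped from the return annotation of get_multimodal_embeddings(), and the "no multimodal input" early return now yields an empty sequence instead of None, so callers can gate the merge step with a plain truthiness check. A minimal standalone sketch of the before/after contract (the MultiModalEmbeddings alias and the trivial bodies below are illustrative placeholders, not the vLLM definitions):

from typing import Optional, Union

import torch

# Illustrative stand-in for the type exported from
# vllm.model_executor.models.interfaces (a sequence of per-item tensors).
MultiModalEmbeddings = Union[list[torch.Tensor], tuple[torch.Tensor, ...]]


def old_style(image_input: Optional[torch.Tensor]) -> Optional[MultiModalEmbeddings]:
    # Old contract: None signals "no multimodal inputs in this batch".
    if image_input is None:
        return None
    return [image_input]


def new_style(image_input: Optional[torch.Tensor]) -> MultiModalEmbeddings:
    # New contract: an empty sequence carries the same meaning, so the
    # Optional wrapper disappears and empty results are simply falsy.
    if image_input is None:
        return []
    return [image_input]


assert old_style(None) is None
assert new_style(None) == [] and not new_style(None)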


@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         multimodal_embeddings = self._process_image_input(image_input)
         return multimodal_embeddings


@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input, **kwargs)


@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings


@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         assert self.model.vqmodel is not None
         image_tokens = self.model.get_image_tokens(image_input["data"].to(
             self.config.torch_dtype))


@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings


@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings


@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings


@@ -706,10 +706,11 @@ class GraniteSpeechForConditionalGeneration(
     def get_multimodal_embeddings(
         self,
         **kwargs: object,
-    ) -> Optional[MultiModalEmbeddings]:
+    ) -> MultiModalEmbeddings:
         """Compute the audio embeddings if audio inputs are present."""
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
         audio_features = self._process_audio_input(audio_input)
         return audio_features


@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol):
     MRO of your model class.
     """

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         """
         Returns multimodal embeddings generated from multimodal kwargs
         to be merged with text embeddings.
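
Under the updated protocol, an implementation returns an empty sequence when a request carries no multimodal data, and callers gate the merge step with a truthiness check rather than an is-None comparison (as the LlavaNext and Phi3V hunks below do). A hedged sketch of that interplay with a toy model; the class, the pixel_values kwarg, and the embedding shapes are made up for illustration and are not vLLM code:

import torch


class ToyVisionLM:
    """Illustrative stand-in for a SupportsMultiModal implementation."""

    def get_multimodal_embeddings(self, **kwargs: object):
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return []  # empty sequence instead of None
        # One embedding tensor per multimodal item, as the interface expects.
        return [torch.zeros(16, 32) for _ in pixel_values]

    def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
        inputs_embeds = torch.zeros(len(input_ids), 32)
        if multimodal_embeddings:  # falsy when empty: skip the merge
            pass  # merge the multimodal embeddings into inputs_embeds here
        return inputs_embeds


model = ToyVisionLM()
assert model.get_multimodal_embeddings() == []  # text-only request
assert len(model.get_multimodal_embeddings(
    pixel_values=[torch.zeros(3, 224, 224)])) == 1  # one image item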


@@ -1304,11 +1304,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each


@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
-        if multimodal_embeddings is None:
+        if not multimodal_embeddings:
             return self.language_model.get_input_embeddings(input_ids)

         inputs_embeds = embed_multimodal(


@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         video_input = self._parse_and_validate_video_input(**kwargs)
         if video_input is None:
-            return None
+            return []
         vision_embeddings = self._process_video_pixels(video_input)
         return vision_embeddings


@@ -839,11 +839,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each


@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.llm

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []
         return self._process_multimodal_inputs(modalities)


@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)


@@ -794,11 +794,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self,
-                                  **kwargs) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -1473,11 +1473,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -499,11 +499,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
         return tuple(vision_embeddings)

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         image_features = self._process_image_input(image_input)


@@ -338,11 +338,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
         vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)


@@ -655,11 +655,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
@@ -669,7 +669,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
-        if multimodal_embeddings is not None:
+        if multimodal_embeddings:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 self.image_token_id)


@@ -1112,11 +1112,12 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             image_attention_mask)
         return image_embeds

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each


@@ -409,11 +409,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -772,13 +772,13 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).


@@ -1016,13 +1016,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             **kwargs)
         if not mm_input_by_modality:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each
         # tensor correspoending to a multimodal data item (image or video).


@@ -350,11 +350,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
         masked_audio_features = self._process_audio_input(audio_input)
         return masked_audio_features


@@ -1257,11 +1257,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return None
+            return []

         # The result multimodal_embeddings is tuple of tensors, with each


@@ -738,11 +738,11 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings


@@ -869,11 +869,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)


@@ -585,11 +585,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
-            return None
+            return []
         return self._process_image_input(image_input)

     def get_input_embeddings(


@@ -546,11 +546,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
-            return None
+            return []
         audio_embeddings = self._process_audio_input(audio_input)
         return audio_embeddings


@@ -687,8 +687,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
     def get_language_model(self) -> torch.nn.Module:
         return self.model.decoder

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
         # TODO: This method does not obey the interface for SupportsMultiModal.
         # Refactor this once encoder/decoder support is implemented in V1.
         audio_input = self._parse_and_validate_audio_input(**kwargs)


@@ -4,11 +4,12 @@ from typing import Optional

 import torch

+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec


 def sanity_check_mm_encoder_outputs(
-    mm_embeddings: object,
+    mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
 ) -> None:
     """