[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Cyrus Leung 2025-09-27 16:15:12 +08:00 committed by yewentao256
parent 6970fa9937
commit 0b8166aa8f
80 changed files with 966 additions and 1139 deletions

View File

@ -66,35 +66,12 @@ Further update the model as follows:
!!! important
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
!!! note
By default, vLLM merges the multimodal embeddings into the text embeddings based on the location information recorded in
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] during input processing.
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
??? code
```python
from .utils import merge_multimodal_embeddings
class YourModelForImage2Seq(nn.Module):
...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)
return inputs_embeds
```
You may override this method if additional logic is required for your model when merging embeddings (see the sketch of the override pattern below).
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
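As a reference for the convention introduced by this PR, an override now accepts the `is_multimodal` mask and delegates to the default index-based merge in `SupportsMultiModal`. Below is a minimal sketch of that pattern, mirroring the updated models in this diff (the class name and the bookkeeping comment are illustrative, not taken from the diff):

```python
from typing import Optional

import torch
import torch.nn as nn

from .interfaces import MultiModalEmbeddings, SupportsMultiModal


class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # Model-specific bookkeeping (e.g. caching a visual-token mask)
        # would go here, before delegating to the default merge.

        # This keeps the type checker happy for the text-only overload.
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )
```

The default implementation obtains text embeddings through `self.get_language_model().get_input_embeddings(...)`, which is why the `get_language_model` getter is also required, and the caller builds the mask itself, e.g. `is_multimodal=input_ids == config.image_token_index`.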

View File

@ -509,9 +509,14 @@ class ModelConfig:
else: # task == "auto"
pass
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError("The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}.")
f"{self.task!r}. Found: {debug_info}")
self.runner = runner
self.convert = convert

View File

@ -38,8 +38,7 @@ from .idefics2_vision_model import (
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
merge_multimodal_embeddings)
is_pp_missing_parameter, maybe_prefix)
class AriaImagePixelInputs(TensorSchema):
@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
multimodal_embeddings = self._process_image_input(image_input)
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model(

View File

@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
class AyaVisionImagePixelInputs(TensorSchema):
@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(

View File

@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
Pooler.for_encode(pooler_config),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)

View File

@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
loaded_params = loader.load_weights(weights)
return loaded_params
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.new.get_input_embeddings(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -27,7 +27,7 @@ from .blip import BlipVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
logger = init_logger(__name__)
@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
image_token_id = self.model.vocabulary_mapping.image_token_id
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == image_token_id,
)
input_ids = None
hidden_states = self.model(input_ids,

View File

@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
class Cohere2VisionImagePixelInputs(TensorSchema):
@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(

View File

@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix=maybe_prefix(
prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
# The image token id may vary
_IMAGE_TOKEN = "<image>"
@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
self.vision = self._init_vision_module(self.vision_config,
quant_config,
@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_token_id,
)
input_ids = None
hidden_states = self.language_model(input_ids,

View File

@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
Qwen2VLProcessingInfo)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings)
maybe_prefix)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
inputs_embeds = None
else:
assert input_ids is not None
inputs_embeds = self.get_multimodal_embeddings(
input_ids,
image_input=image_input,
)
input_ids = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model(
input_ids=input_ids,

View File

@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
if multimodal_embeddings is None:
return inputs_embeds
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
multimodal_embeddings,
[self.config.im_patch_id])
return inputs_embeds
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,

View File

@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
_IMAGE_TOKEN_ID,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None
hidden_states = self.language_model(

View File

@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
logger = init_logger(__name__)
@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
kwargs = self.prepare_attn_masks(
input_ids,

View File

@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward only has access to the inputs_embeds.
if input_ids is not None:
@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
per_layer_inputs)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
# NOTE: this order of processing mm items is important
[self.config.image_token_id, self.config.audio_token_id])
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(self,
input_ids: torch.Tensor,

View File

@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0
and all(embed.numel() > 0 for embed in multimodal_embeddings)):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id],
)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
prefix=maybe_prefix(
prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import flatten_bn, merge_multimodal_embeddings
from .utils import flatten_bn, isin_list
class GLMVImagePixelInputs(TensorSchema):
@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
self.config.pad_token_id,
self.config.eoi_token_id,
],
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, [
self.config.boi_token_id,
self.config.pad_token_id,
self.config.eoi_token_id,
]),
)
input_ids = None
hidden_states = self.transformer(input_ids, positions,

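Several call sites in this diff (GLM-4V above, plus HCXVision, InternS1/InternVL, and MiniCPM-V further below) build the `is_multimodal` mask with an `isin_list` helper imported from `.utils`. The helper's implementation is not shown in the diff; a minimal sketch of what it might look like, assuming it is a thin wrapper around `torch.isin`:

```python
import torch


def isin_list(elements: torch.Tensor, test_elements: list[int]) -> torch.Tensor:
    """Boolean mask marking which entries of `elements` appear in `test_elements`."""
    return torch.isin(
        elements,
        torch.tensor(test_elements, device=elements.device),
    )
```

For example, `isin_list(input_ids, [boi_token_id, pad_token_id, eoi_token_id])` produces the mask that GLM-4V passes as `is_multimodal` above.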
View File

@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, embed_multimodal,
init_vllm_registered_model, maybe_prefix)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
### Audio Input
@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration(
# Split variable length features into a tuple
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self,
**kwargs: object,
@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration(
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return []
return None
audio_features = self._process_audio_input(audio_input)
return audio_features
@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Compute the merged LLM / audio embeddings."""
if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
return super().get_input_embeddings(
input_ids,
self.config.audio_token_index,
self.language_model.model.get_input_embeddings,
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return inputs_embeds
def forward(
self,
@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration(
# condition is for v0 compatibility.
elif inputs_embeds is None:
audio_embeds = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
inputs_embeds = self.get_input_embeddings(
input_ids,
audio_embeds,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None
model_output = self.language_model(input_ids, positions,

View File

@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
maybe_prefix)
from .vision import get_vision_encoder_info
EOT = "<|endofturn|>"
@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_multimodal_embeddings(
self,
**kwargs: Unpack[HCXVisionMultimodalInputs],
) -> Optional[MultiModalEmbeddings]:
) -> MultiModalEmbeddings:
multimodal_embeddings = list()
if kwargs.get("pixel_values_images") is not None:
@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings.append(_multimodal_embeddings_videos)
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
placeholder_token_id=[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=isin_list(
input_ids,
[self.config.image_token_id, self.config.video_token_id]),
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@ -52,8 +52,7 @@ from .idefics2_vision_model import (
# yapf: enable
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
class Idefics3ImagePixelInputs(TensorSchema):
@ -539,10 +538,7 @@ class Idefics3Model(nn.Module):
return image_hidden_states
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.text_model.get_input_embeddings(input_ids)
def forward(
@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.model.text_model(input_ids,

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional,
Protocol, Union, overload, runtime_checkable)
import numpy as np
import torch
@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model
from .interfaces_base import VllmModel, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol):
"""
...
def get_language_model(self) -> torch.nn.Module:
def get_language_model(self) -> VllmModel:
"""
Returns the underlying language model used for text generation.
@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
"""
...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
...
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor:
...
def _get_text_embeddings(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
*,
is_multimodal: Optional[Tensor],
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[Tensor] = None,
handle_oov_mm_token: bool = False,
) -> Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
the embeddings of `input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note, however, that doing so increases memory usage,
as an additional buffer is needed to hold the input embeddings.
"""
...
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
@runtime_checkable

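To make the new contract concrete, below is a small self-contained sketch of the index-based scatter that the default `get_input_embeddings` ultimately performs via `_merge_multimodal_embeddings` (the toy tensors and the placeholder id 900 are illustrative, not taken from the diff):

```python
import torch

hidden_size = 4
input_ids = torch.tensor([11, 900, 900, 42])   # 900 = hypothetical image placeholder id
is_multimodal = input_ids == 900               # the mask is now computed by the caller

# Stand-ins for the LM's token embeddings and the vision encoder's outputs.
text_embeds = torch.zeros(input_ids.shape[0], hidden_size)
mm_embeds = torch.ones(int(is_multimodal.sum()), hidden_size)

# Multimodal embeddings are scattered into the masked positions by index,
# regardless of which token ids happen to occupy those positions.
inputs_embeds = text_embeds.masked_scatter(is_multimodal.unsqueeze(-1), mm_embeds)

assert torch.equal(inputs_embeds[is_multimodal], mm_embeds)
```

Models whose placeholder ids can exceed the LM vocabulary (e.g. GraniteSpeech and LLaVA-NeXT in this diff) pass `handle_oov_mm_token=True`, which makes `_get_text_embeddings` skip the LM embedding lookup for the masked positions at the cost of an extra output buffer.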
View File

@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
) -> None:
...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
"""Apply token embeddings to `input_ids`."""
...
def forward(
self,
input_ids: torch.Tensor,
@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(
model: Union[type[object], object]) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if not callable(model_get_input_embeddings):
logger.warning(
"The model (%s) is missing the `get_input_embeddings` method.",
model,
)
return False
return True
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model(
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
return (_check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model)
and _check_vllm_model_forward(model))
@runtime_checkable

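With the stricter `is_vllm_model` check above, a model class must now expose `get_input_embeddings` in addition to a `vllm_config`-aware `__init__` and a `forward` method. A minimal class that would satisfy all three checks might look like the following (purely illustrative, not part of the diff):

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, *, vllm_config, prefix: str = "") -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return self.get_input_embeddings(input_ids)
```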
View File

@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, isin_list, maybe_prefix)
class InternS1MultiModalProjector(nn.Module):
@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None
forward_kwargs = {

View File

@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
isin_list, maybe_prefix)
IMG_START = '<img>'
IMG_END = '</img>'
@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None
forward_kwargs = {

View File

@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module):
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
SupportsPP)
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.media_placeholder_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
inputs_embeds = None
else:
inputs_embeds = self.get_input_embeddings(input_ids)
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
inputs_embeds = self.get_input_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.
media_placeholder_token_id,
is_multimodal=input_ids ==
self.config.media_placeholder_token_id,
)
input_ids = None

View File

@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
@ -79,10 +79,7 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def get_language_model(self) -> torch.nn.Module:
return self.model
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore
def forward(
self,
input_ids: torch.Tensor,
@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
skip_prefixes=(["lm_head."]),
)
loader.load_weights(map(transform, weights))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds

View File

@ -73,6 +73,9 @@ class LlamaModel(nn.Module):
self.config.hidden_size,
bias=False)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
@ -144,10 +143,7 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps,
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
return inputs_embeds

View File

@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TensorSchema):
@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
return super().get_input_embeddings(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return inputs_embeds
def forward(
self,
@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_video_pixels(video_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.video_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -54,8 +54,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
_Tuple2 = Union[int, tuple[int, int], Sequence[int]]
@ -744,21 +743,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return []
return self._process_audio_input(audio_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.decoder.get_input_embeddings(input_ids)
if multimodal_embeddings and len(multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.audio_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -771,8 +755,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = None
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_id,
)
input_ids = None
return self.decoder.model(input_ids,

View File

@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -158,6 +161,9 @@ class MiMoMTP(nn.Module):
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -71,8 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@ -1144,23 +1143,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_multimodal_inputs(modalities)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
list(self.mm_token_ids),
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -1178,8 +1160,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, list(self.mm_token_ids)),
)
input_ids = None
hidden_states = self.llm.model(

View File

@ -592,10 +592,7 @@ class MiniMaxText01Model(nn.Module):
dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(self,
@ -687,10 +684,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,

View File

@ -28,7 +28,7 @@ from .llava_next import LlavaNextProcessingInfo
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
class MiniMaxVL01ImagePixelInputs(TensorSchema):
@ -218,22 +218,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
@ -403,8 +387,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -38,8 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
@ -524,22 +523,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -592,8 +575,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@ -56,8 +56,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
@ -813,24 +812,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -846,8 +827,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
return self.language_model(input_ids, positions, intermediate_tensors,

View File

@ -43,6 +43,9 @@ class ModernBertEmbeddings(nn.Module):
eps=config.layer_norm_eps,
bias=config.norm_bias)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -220,6 +223,9 @@ class ModernBertModel(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
@ -333,6 +339,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = []

View File

@ -58,7 +58,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@ -819,10 +819,7 @@ class MolmoModel(nn.Module, SupportsQuant):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
@ -1481,24 +1478,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_patch_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.LongTensor,
@ -1515,8 +1494,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_patch_id,
)
input_ids = None
hidden_states = self.model(input_ids,

View File

@ -35,8 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (flatten_bn,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings)
isin_list, maybe_prefix)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, MultiModalKwargsItems,
@ -1096,8 +1095,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
return modalities
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# Validate the multimodal input keyword arguments
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities is None:
@ -1121,30 +1120,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
@ -1163,9 +1138,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None
hidden_states = self.language_model(

View File

@ -38,7 +38,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
IMG_START = '<img>'
IMG_END = '</img>'
@ -576,20 +576,24 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [self.img_context_token_id]
assert len(context_token_ids) >= 1
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
@ -608,8 +612,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None
forward_kwargs = {

View File

@ -295,6 +295,9 @@ class Olmo2Model(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"],
self.config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -408,6 +411,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -48,7 +48,6 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
# Cannot find the following number from hf config.
IMAGE_TOKEN = "<image>"
@ -501,19 +500,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
return image_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_pad_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -529,8 +515,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None
# up until here we have an inputs_embeds 100% numerical identity

View File

@ -585,17 +585,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
tmp = torch.concat(multimodal_embeddings, dim=0)
inputs_embeds[input_ids == self.image_pad_token_id] = tmp
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -612,8 +601,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None
# up until here we have an inputs_embeds 100% numerical identity

View File

@ -26,8 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
@ -362,19 +361,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@ -388,8 +374,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -51,9 +51,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper,
_merge_multimodal_embeddings, flatten_bn,
init_vllm_registered_model, maybe_prefix)
logger = init_logger(__name__)
@ -643,14 +643,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
inputs_embeds = self._get_text_embeddings(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward(self,
input_ids: torch.Tensor,
@ -666,8 +683,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -1342,12 +1342,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
image_attention_mask)
return image_embeds
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return []
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
@ -1371,18 +1371,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -1151,7 +1151,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
return None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
@ -1175,19 +1174,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -50,8 +50,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
try:
@ -433,22 +432,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.vision_args.image_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -465,8 +448,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.vision_args.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -865,24 +865,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings += audio_embeddings
return multimodal_embeddings
# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.config.image_token_index,
self.config.video_token_index,
self.config.audio_token_index
])
return inputs_embeds
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings_v0(
self, **kwargs: object) -> Optional[NestedTensors]:

View File

@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
# # === Audio Inputs === #
@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,

View File

@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

View File

@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention,
from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
_merge_multimodal_embeddings, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)
@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings
def _compute_deepstack_embeds(
self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor:
visual_lens = [
x.shape[0] if isinstance(x, torch.Tensor) else len(x)
for x in multimodal_embeddings
]
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
dim=-1)
(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)
multimodal_embeddings = torch.split(multimodal_embeddings_main,
visual_lens,
@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))
deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
multimodal_embeddings_multiscale,
placeholder_token_id=[
self.config.image_token_id, self.config.video_token_id
],
deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
return deepstack_input_embeds, multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
inputs_embeds = self._get_text_embeddings(
input_ids,
self.language_model.get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
if self.use_deepstack:
if deepstack_input_embeds is None:
deepstack_input_embeds = torch.zeros_like(
inputs_embeds).unsqueeze(0).repeat(
self.deepstack_num_level, 1, 1).contiguous()
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
if deepstack_input_embeds is not None:
deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze(
0).repeat(self.deepstack_num_level, 1, 1).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
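To make the deepstack path above easier to follow, here is a toy sketch with assumed dimensions (not the model's real `visual_dim` or level count): each visual token carries its main embedding plus one extra embedding per deepstack level, concatenated along the hidden dimension, and the two parts are separated with a single `torch.split`.

```python
import torch

# Assumed dimensions for illustration only.
visual_dim, num_levels = 8, 3
multiscale_dim = visual_dim * num_levels
mm_embeds = [torch.randn(5, visual_dim + multiscale_dim)]  # one image, 5 tokens

cat = torch.cat(mm_embeds, dim=0)
main, multiscale = torch.split(cat, [visual_dim, multiscale_dim], dim=-1)
# `main` is merged into inputs_embeds like any other modality; `multiscale`
# is regrouped per level before being scattered into the deepstack buffer.
per_level = multiscale.reshape(-1, num_levels, visual_dim)
```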

View File

@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings
from .utils import flatten_bn
class QwenImagePixelInputs(TensorSchema):
@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids ==
self.transformer.visual.image_pad_id,
)
input_ids = None
hidden_states = self.transformer(input_ids, positions,

View File

@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
IMG_START = '<img>'
IMG_END = '</img>'
@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_context_token_id,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None
forward_kwargs = {

View File

@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,

View File

@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,

View File

@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import run_dp_sharded_vision_model
@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
1 else cur_feature[0])
return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
else:
is_text = input_ids != self.config.image_token_id
text_ids = input_ids[is_text]
text_embeds = self.language_model.model.get_input_embeddings(
text_ids)
inputs_embeds = torch.empty(input_ids.shape[0],
text_embeds.shape[-1],
dtype=text_embeds.dtype,
device=text_embeds.device)
inputs_embeds[is_text] = text_embeds
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model(input_ids,
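The `handle_oov_mm_token=True` default above matters because this model's placeholder token ID can fall outside the language model's embedding table. The toy sketch below uses invented sizes and IDs (not this model's real configuration) to show why the text embedding lookup has to skip the placeholder positions.

```python
import torch
import torch.nn as nn

# Invented sizes: a 10-entry vocabulary, but the image placeholder ID is 32.
embed_tokens = nn.Embedding(10, 4)
image_token_id = 32
input_ids = torch.tensor([1, 32, 32, 2])
is_multimodal = input_ids == image_token_id

# embed_tokens(input_ids) would fail on the out-of-vocabulary IDs, so only the
# text positions are embedded; the masked rows are later filled by the merge.
inputs_embeds = torch.zeros(input_ids.shape[0], 4)
inputs_embeds[~is_multimodal] = embed_tokens(input_ids[~is_multimodal])
```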

View File

@ -40,7 +40,7 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
from .vision import VisionEncoderInfo, get_vision_encoder_info
@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return []
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,

View File

@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in

View File

@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix)
@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase):
else:
self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings()(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,
@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
if multimodal_embeds is not None:
inputs_embeds = self.get_input_embeddings(
input_ids, multimodal_embeds)
input_ids,
multimodal_embeds,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
model_output = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds)
return model_output
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop(
@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings=None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)
"""
Apply token embeddings to `input_ids`.
inputs_embeds = inputs_embeds.masked_scatter(
mask, multimodal_embeddings)
return inputs_embeds
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens.
"""
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.model.get_input_embeddings(),
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
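As the `ValueError` above indicates, callers (typically the model runner) are now expected to supply the placeholder mask themselves. The sketch below is purely hypothetical — the request layout, `placeholder_offsets`, and shapes are invented for illustration — but it shows the intended call shape: the mask is built from known placeholder positions rather than by comparing token IDs.

```python
import torch

# Hypothetical per-request state, not actual vLLM runner internals.
input_ids = torch.tensor([101, 32000, 32000, 32000, 7592, 102])
placeholder_offsets = [(1, 3)]  # (start index, length) of each multimodal item
hidden_size = 4096
mm_embeds = [torch.randn(3, hidden_size)]  # one item -> 3 feature tokens

# Build the placeholder mask by position instead of by token ID.
is_multimodal = torch.zeros_like(input_ids, dtype=torch.bool)
for start, length in placeholder_offsets:
    is_multimodal[start:start + length] = True

# The runner would then call (assuming `model` implements SupportsMultiModal):
# inputs_embeds = model.get_input_embeddings(
#     input_ids,
#     multimodal_embeddings=mm_embeds,
#     is_multimodal=is_multimodal,
# )
```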

View File

@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16
@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# The audio token index is not included in the embedding table
# We need to remove it before embedding lookup
safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
inputs_embeds = self.language_model.get_input_embeddings(
safe_input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(self,
input_ids: torch.Tensor,
@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None
language_model = self.language_model

View File

@ -4,7 +4,7 @@
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
from typing import Any, Literal, Optional, Protocol, Union, overload
import torch
import torch.nn as nn
@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@ -402,63 +402,37 @@ def _merge_multimodal_embeddings(
Note:
This updates ``inputs_embeds`` in place.
"""
flattened = _flatten_embeddings(multimodal_embeddings)
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype))
except RuntimeError as e:
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if len(multimodal_embeddings) == 0:
return inputs_embeds
if flattened.shape[0] != num_expected_tokens:
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings``, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
This updates ``inputs_embeds`` in place.
"""
if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor(
placeholder_token_id,
pin_memory=is_pin_memory_available()).to(device=input_ids.device,
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
is_multimodal = isin_list(input_ids, placeholder_token_id)
else:
is_multimodal = (input_ids == placeholder_token_id)
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)
return torch.isin(elements, test_elements)
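As a quick illustration of the index-based merge introduced above, the sketch below uses made-up tensor shapes and values (none of them come from the diff): a boolean `is_multimodal` mask marks the placeholder rows, and the flattened multimodal embeddings are scattered into them in place.

```python
import torch

# Illustrative shapes only: 6 text tokens, hidden size 4, one image item
# contributing 2 feature tokens.
hidden_size = 4
inputs_embeds = torch.zeros(6, hidden_size)
mm_embeds = [torch.ones(2, hidden_size)]
is_multimodal = torch.tensor([False, True, True, False, False, False])

flat = torch.cat(mm_embeds, dim=0)
# The merge helper above only verifies this count if masked_scatter_ raises.
assert int(is_multimodal.sum()) == flat.shape[0]
# Equivalent to inputs_embeds[is_multimodal] = flat, without a D2H sync.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
                              flat.to(dtype=inputs_embeds.dtype))
```

When a model uses more than one placeholder ID (image and video, for example), the same mask can be produced with the `isin_list` helper defined above.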
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module:

View File

@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsTranscription)
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
logger = init_logger(__name__)
@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
audio_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
audio_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
audio_embeddings,
is_multimodal=input_ids == audio_tok_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return audio_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
return inputs_embeds
def _parse_and_validate_audio_arrays(
self, **kwargs: object) -> Union[list[torch.Tensor], None]:
audio_arrays = kwargs.pop("audio_arrays", None)

View File

@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.

View File

@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@ -64,8 +65,10 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
vllm_config.model_config)
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
@ -175,7 +178,8 @@ class EagleProposer:
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@ -219,18 +223,21 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
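The drafter now receives the embeddings and the mask as a single optional tuple; a small sketch of that contract and the `or (None, None)` unpacking used above (shapes are illustrative):

```python
from typing import Optional
import torch

MMEmbedInputs = tuple[list[torch.Tensor], torch.Tensor]  # (per-item embeds, bool mask)

def unpack_mm_embed_inputs(
    mm_embed_inputs: Optional[MMEmbedInputs],
) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]:
    # `x or (None, None)` covers both the text-only step (None) and the
    # multimodal case in one line.
    mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
    return mm_embeds, is_mm_embed

print(unpack_mm_embed_inputs(None))  # (None, None)
mask = torch.tensor([True, True, False])
embeds, got_mask = unpack_mm_embed_inputs(([torch.zeros(2, 8)], mask))
```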
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
@ -372,14 +379,15 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = \
self.model.get_input_embeddings(input_ids)
input_ids = None
inputs_embeds = self.inputs_embeds[:input_batch_size]
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
# Run the model.
with set_forward_context(per_layer_attn_metadata,
@ -849,7 +857,7 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names)
if self.is_multimodal_model:
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
# text-only draft models
try:
@ -861,7 +869,7 @@ class EagleProposer:
logger.warning(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode")
self.is_multimodal_model = False
self.supports_mm_inputs = False
if supports_multimodal(target_model):
# handle multimodality
@ -933,7 +941,7 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.is_multimodal_model:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:

View File

@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions(
@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens,
))
assert req_state.mrope_positions is not None
req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(
scheduler_output.total_num_scheduled_tokens)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds
return mm_embeds, is_mm_embed
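The core of the new gather is the per-token mask assignment a few lines above. A toy, self-contained sketch of that arithmetic (the helper name is hypothetical; the offsets mirror the loop variables):

```python
import torch

def mark_mm_positions(
    is_mm_embed: torch.Tensor,      # (total_scheduled_tokens,) bool, pre-zeroed
    req_start_idx: int,             # where this request's tokens start in the batch
    num_computed_tokens: int,       # tokens already processed for the request
    start_pos: int,                 # PlaceholderRange.offset of the mm item
    start_idx: int,                 # first encoder token scheduled this step
    end_idx: int,                   # one past the last scheduled encoder token
    is_embed: torch.Tensor | None,  # optional sub-mask from PlaceholderRange
) -> None:
    req_start_pos = req_start_idx + start_pos - num_computed_tokens
    sl = slice(req_start_pos + start_idx, req_start_pos + end_idx)
    is_mm_embed[sl] = True if is_embed is None else is_embed

mask = torch.zeros(10, dtype=torch.bool)
mark_mm_positions(mask, req_start_idx=0, num_computed_tokens=2,
                  start_pos=4, start_idx=0, end_idx=3, is_embed=None)
print(mask)  # positions 2..4 are True
```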
def _extract_encoder_inputs(
self,
@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
# TODO(woosuk): Avoid the copy. Optimize.
@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
mm_embed_inputs = self._gather_mm_embeddings(
scheduler_output,
shift_computed_tokens=1,
)
else:
mm_embed_inputs = None
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_token_indices=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
mm_embed_inputs=mm_embed_inputs,
)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:

View File

@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory)
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context
@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens)
is_mm_embed = self.is_mm_embed_cpu
is_mm_embed[:padded_total_num_scheduled_tokens] = False
mm_embeds = list[torch.Tensor]()
req_start_idx = 0
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens,
)
assert start_idx < end_idx
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings."
encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output)
return mm_embeds
def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True
# Only whole mm items are processed
mm_embeds.append(encoder_output)
req_start_idx += num_scheduled_tokens
is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \
.to(self.device)
return mm_embeds, is_mm_embed
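On TPU the mask is sliced to a padded token count rather than the exact number of scheduled tokens, keeping tensor shapes stable for XLA. A small sketch of that padding lookup, assuming a sorted list of padding buckets (this mirrors `_get_padded_token_len` only in spirit):

```python
import bisect
import torch

def padded_token_len(paddings: list[int], num_tokens: int) -> int:
    # Pick the smallest pre-compiled bucket that fits the scheduled tokens.
    idx = bisect.bisect_left(paddings, num_tokens)
    return paddings[idx]

paddings = [16, 32, 64, 128]
padded = padded_token_len(paddings, 40)            # 64
is_mm_embed = torch.zeros(128, dtype=torch.bool)   # persistent CPU buffer
is_mm_embed[:padded] = False                       # reset only the used slice
mask_for_device = is_mm_embed[:padded]             # fixed-size slice -> stable shapes
```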
def _get_model_inputs(
self,
input_ids: torch.Tensor,
mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]],
):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
return None, inputs_embeds
else:
# For text-only models, we use token ids as input.
@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
mm_embed_inputs = None
torch_xla.sync(wait=False)
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
self.input_ids, mm_embed_inputs)
torch_xla.sync(wait=False)
# Run the decoder
with set_forward_context(
@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hf_config.image_token_index
placeholders_ids = placeholders_ids.to(self.device)
mm_mask = torch.tensor([False] * num_tokens)
mm_mask[:items_size] = True
mm_mask = mm_mask.to(self.device)
# Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(placeholders_ids,
[mm_embeds])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=([mm_embeds], mm_mask),
)
assert a is None
torch_xla.sync(wait=False)
@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device="cpu")
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, [])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=None,
)
assert a is None
torch_xla.sync(wait=False)
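The precompile/dummy path above builds the same kind of mask with all multimodal positions packed at the front. A compact sketch of that construction (sizes are illustrative):

```python
import torch

def dummy_mm_mask(num_tokens: int, items_size: int) -> torch.Tensor:
    # The warm-up run places the dummy multimodal embeddings contiguously at
    # the start of the sequence, so the mask is True for the first
    # `items_size` positions and False elsewhere.
    mm_mask = torch.zeros(num_tokens, dtype=torch.bool)
    mm_mask[:items_size] = True
    return mm_mask

print(dummy_mm_mask(num_tokens=8, items_size=3))
# tensor([ True,  True,  True, False, False, False, False, False])
```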