[V0 Deprecation] Remove V0 logic from get_input_embeddings interface (#25242)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored 2025-09-19 19:10:52 +08:00, committed by GitHub
commit 5089fd749c (parent a3d087adec)
4 changed files with 21 additions and 83 deletions

vllm/model_executor/models/hyperclovax_vision.py

@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        **kwargs,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if (kwargs.get("pixel_values_images") is not None
-                or kwargs.get("pixel_values_videos")
-                is not None):  # v0 compatibility
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-        if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
-            _mask_image = input_ids == self.config.image_token_id
-            _mask_video = input_ids == self.config.video_token_id
-            assert _mask_image.sum() + _mask_video.sum() == len(
-                multimodal_embeddings)
-
-            if multimodal_embeddings.dtype != inputs_embeds.dtype:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    dtype=inputs_embeds.dtype)
-            if multimodal_embeddings.device != inputs_embeds.device:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    device=inputs_embeds.device)
-            if _mask_image.sum() > 0:
-                inputs_embeds[
-                    _mask_image] = multimodal_embeddings[:sum(_mask_image)]
-            if _mask_video.sum() > 0:
-                inputs_embeds[_mask_video] = multimodal_embeddings[
-                    -sum(_mask_video):]
+        if multimodal_embeddings is not None \
+            and len(multimodal_embeddings) != 0:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.image_token_id,
+                    self.config.video_token_id,
+                ],
+            )
 
         return inputs_embeds
 
     def forward(
@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
-                                                      **kwargs)
+            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      multimodal_embeddings)
             input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
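
For readers less familiar with the shared helper, the following is a hedged, self-contained sketch (toy tensors and toy token ids, not vLLM's actual merge_multimodal_embeddings) of what merging against a list of placeholder token ids amounts to: every image or video placeholder slot in input_ids is overwritten, in order, by the precomputed multimodal embeddings. This is the same effect the removed hand-rolled masking achieved, without the per-modality dtype/device bookkeeping.

import torch

# Toy placeholder ids; the real values come from the model config
# (config.image_token_id / config.video_token_id).
IMAGE_TOKEN_ID = 32000
VIDEO_TOKEN_ID = 32001
HIDDEN_SIZE = 4

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7, VIDEO_TOKEN_ID, 2])
inputs_embeds = torch.zeros(len(input_ids), HIDDEN_SIZE)
mm_embeds = torch.arange(3 * HIDDEN_SIZE, dtype=torch.float32).reshape(3, HIDDEN_SIZE)

# Mask of every position holding any placeholder token, in sequence order.
is_mm = torch.isin(input_ids, torch.tensor([IMAGE_TOKEN_ID, VIDEO_TOKEN_ID]))
assert int(is_mm.sum()) == mm_embeds.shape[0]

# Scatter: the i-th multimodal embedding lands in the i-th placeholder slot.
inputs_embeds[is_mm] = mm_embeds.to(dtype=inputs_embeds.dtype)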

vllm/model_executor/models/interfaces.py

@@ -23,7 +23,6 @@ from vllm.utils import supports_kw
 from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
@@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol):
         """
         ...
 
-    # Only for models that support v0 chunked prefill
-    # TODO(ywang96): Remove this overload once v0 is deprecated
-    @overload
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> Tensor:
-        ...
-
-    # TODO: Remove this overload once v0 is deprecated
-    @overload
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> Tensor:
-        ...
-
     def get_input_embeddings(
         self,
         input_ids: Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        # Only necessary so that the v0 overload is valid
-        # TODO: Remove attn_metadata once v0 is deprecated
-        attn_metadata: Optional["AttentionMetadata"] = None,
     ) -> Tensor:
         """
         Returns the input embeddings merged from the text embeddings from
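
To make the shape of the trimmed interface concrete, here is a minimal hedged sketch of a conforming implementation (the class name, embedding layer, and token id are illustrative, not vLLM code): a single get_input_embeddings with no @overload variants and no attn_metadata parameter.

from typing import Optional

import torch
import torch.nn as nn


class ToyMultiModalModel(nn.Module):
    """Illustrative model exposing the simplified interface."""

    IMAGE_TOKEN_ID = 32000  # assumed toy value

    def __init__(self, vocab_size: int = 32064, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
            # Overwrite placeholder positions with the precomputed embeddings.
            is_mm = input_ids == self.IMAGE_TOKEN_ID
            inputs_embeds[is_mm] = multimodal_embeddings.to(inputs_embeds.dtype)
        return inputs_embeds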

vllm/model_executor/models/ultravox.py

@@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
-from vllm import envs
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader import DefaultModelLoader
@@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings,
-                    merge_multimodal_embeddings_from_map)
+                    merge_multimodal_embeddings)
 
 _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16
@@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             safe_input_ids)
         if multimodal_embeddings is not None and len(
                 multimodal_embeddings) > 0:
-
-            # TODO(ywang96): remove this block after v0 is deprecated.
-            if not envs.VLLM_USE_V1:
-                attn_metadata = get_forward_context().attn_metadata
-                merge_multimodal_embeddings_from_map(
-                    inputs_embeds, multimodal_embeddings,
-                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
-            else:
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, multimodal_embeddings,
-                    self.config.audio_token_index)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.audio_token_index)
         return inputs_embeds
 
     def forward(self,
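
The same simplification from the caller's side, as a hedged sketch (the helper name is illustrative, not vLLM's model-runner API): multimodal embeddings are computed once and handed to get_input_embeddings, so the model no longer reaches into get_forward_context()/attn_metadata to locate the audio placeholder slots.

def build_inputs_embeds(model, input_ids, **mm_kwargs):
    # May return an empty sequence when the request has no audio input.
    multimodal_embeddings = model.get_multimodal_embeddings(**mm_kwargs)
    return model.get_input_embeddings(input_ids, multimodal_embeddings)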

vllm/model_executor/models/utils.py

@@ -15,7 +15,7 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
+from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)
@@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
         _embedding_count_expression(inner) for inner in embeddings)
 
 
-def merge_multimodal_embeddings_from_map(
-        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
-        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
-    """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
-    placeholder map .
-
-    Note:
-        This updates ``inputs_embeds`` in place.
-    """
-    flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
-    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
-        placeholder_map.src].to(dtype=inputs_embeds.dtype)
-    return inputs_embeds
-
-
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
     is_multimodal: torch.Tensor,
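
As a hedged, toy-tensor comparison (not vLLM code) of the two merge styles touched here: the removed helper wrote through an externally supplied dest/src index map, while the surviving path derives the destination mask directly from placeholder token ids in input_ids.

import torch

hidden_size = 4
inputs_embeds = torch.zeros(6, hidden_size)
mm_embeds = torch.ones(2, hidden_size)

# Removed style: explicit index map, with dest/src computed elsewhere
# (previously taken from attn_metadata's placeholder index maps).
dest = torch.tensor([1, 4])  # sequence positions to overwrite
src = torch.tensor([0, 1])   # rows of mm_embeds to take
inputs_embeds[dest] = mm_embeds[src].to(dtype=inputs_embeds.dtype)

# Surviving style: mask derived from the placeholder token id itself.
AUDIO_TOKEN_ID = 32002  # assumed toy value
input_ids = torch.tensor([1, AUDIO_TOKEN_ID, 5, 6, AUDIO_TOKEN_ID, 2])
is_mm = input_ids == AUDIO_TOKEN_ID
inputs_embeds[is_mm] = mm_embeds.to(dtype=inputs_embeds.dtype)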