Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 23:55:44 +08:00)
[V0 Deprecation] Remove V0 logic from get_input_embeddings interface (#25242)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a3d087adec, commit 5089fd749c
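In short: the SupportsMultiModal protocol now declares a single get_input_embeddings signature with no attn_metadata parameter and no @overload variants; HCXVisionForCausalLM stops reconstructing multimodal embeddings from **kwargs inside get_input_embeddings and instead receives them explicitly; and the V0-only helper merge_multimodal_embeddings_from_map is deleted from the model utils together with the envs.VLLM_USE_V1 branch in Ultravox that called it.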
vllm/model_executor/models/hyperclovax_vision.py

@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        **kwargs,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if (kwargs.get("pixel_values_images") is not None
-                or kwargs.get("pixel_values_videos")
-                is not None):  # v0 compatibility
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-        if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
-            _mask_image = input_ids == self.config.image_token_id
-            _mask_video = input_ids == self.config.video_token_id
-            assert _mask_image.sum() + _mask_video.sum() == len(
-                multimodal_embeddings)
-
-            if multimodal_embeddings.dtype != inputs_embeds.dtype:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    dtype=inputs_embeds.dtype)
-            if multimodal_embeddings.device != inputs_embeds.device:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    device=inputs_embeds.device)
-
-            if _mask_image.sum() > 0:
-                inputs_embeds[
-                    _mask_image] = multimodal_embeddings[:sum(_mask_image)]
-            if _mask_video.sum() > 0:
-                inputs_embeds[_mask_video] = multimodal_embeddings[
-                    -sum(_mask_video):]
+        if multimodal_embeddings is not None \
+            and len(multimodal_embeddings) != 0:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.image_token_id,
+                    self.config.video_token_id,
+                ],
+            )
         return inputs_embeds
 
     def forward(
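The hand-rolled masking, dtype/device casts, and slice arithmetic deleted above are subsumed by a single call to the shared merge_multimodal_embeddings helper. Below is a minimal sketch of the merge that call performs, assuming the multimodal embeddings arrive as one pre-flattened 2D tensor whose rows are ordered to match the placeholder positions in input_ids (the real helper also handles nested per-item tensors); merge_sketch and its argument names are illustrative, not vLLM API.

import torch

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 mm_embeds: torch.Tensor,
                 placeholder_ids: list[int]) -> torch.Tensor:
    # Mark every sequence position whose token id is a placeholder id
    # (image or video token in the HCXVision case).
    is_mm = torch.isin(
        input_ids, torch.tensor(placeholder_ids, device=input_ids.device))
    # Exactly one embedding row must exist per placeholder position.
    assert int(is_mm.sum()) == mm_embeds.shape[0]
    # Scatter the multimodal rows into the text embeddings, converting to
    # the destination dtype/device (the deleted code did this by hand).
    inputs_embeds[is_mm] = mm_embeds.to(dtype=inputs_embeds.dtype,
                                        device=inputs_embeds.device)
    return inputs_embeds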
@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
-                                                      **kwargs)
+            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      multimodal_embeddings)
             input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
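Combined with the previous hunk, the old V0 path where forward() smuggled raw pixel inputs into get_input_embeddings via **kwargs becomes an explicit two-step sequence. A hedged usage sketch, with model standing in for an HCXVisionForCausalLM instance and kwargs for its multimodal inputs (pixel_values_images / pixel_values_videos):

# Before (V0 compatibility): embeddings were derived inside the call.
# inputs_embeds = model.get_input_embeddings(input_ids=input_ids, **kwargs)

# After (V1 only): encode first, then merge, as two separate interface calls.
multimodal_embeddings = model.get_multimodal_embeddings(**kwargs)
inputs_embeds = model.get_input_embeddings(input_ids, multimodal_embeddings)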
vllm/model_executor/models/interfaces.py

@@ -23,7 +23,6 @@ from vllm.utils import supports_kw
 from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
@@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol):
         """
         ...
 
-    # Only for models that support v0 chunked prefill
-    # TODO(ywang96): Remove this overload once v0 is deprecated
-    @overload
     def get_input_embeddings(
         self,
         input_ids: Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> Tensor:
-        ...
-
-    # TODO: Remove this overload once v0 is deprecated
-    @overload
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> Tensor:
-        ...
-
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        # Only necessary so that the v0 overload is valid
-        # TODO: Remove attn_metadata once v0 is deprecated
-        attn_metadata: Optional["AttentionMetadata"] = None,
     ) -> Tensor:
         """
         Returns the input embeddings merged from the text embeddings from
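The net effect on the interface: one signature, no AttentionMetadata import, nothing V0-specific for implementers to accept. A self-contained sketch of the resulting protocol shape, where the MultiModalEmbeddings alias is a simplified stand-in for vLLM's own definition:

from typing import Optional, Protocol

from torch import Tensor

# Simplified stand-in; vLLM defines MultiModalEmbeddings in this module.
MultiModalEmbeddings = tuple[Tensor, ...]

class SupportsMultiModal(Protocol):

    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        # Returns the text embeddings with the multimodal embeddings
        # merged in at the placeholder token positions.
        ...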
vllm/model_executor/models/ultravox.py

@@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
-from vllm import envs
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader import DefaultModelLoader
@@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings,
-                    merge_multimodal_embeddings_from_map)
+                    merge_multimodal_embeddings)
 
 _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16
@@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             safe_input_ids)
         if multimodal_embeddings is not None and len(
                 multimodal_embeddings) > 0:
-
-            # TODO(ywang96): remove this block after v0 is deprecated.
-            if not envs.VLLM_USE_V1:
-                attn_metadata = get_forward_context().attn_metadata
-                merge_multimodal_embeddings_from_map(
-                    inputs_embeds, multimodal_embeddings,
-                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
-            else:
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, multimodal_embeddings,
-                    self.config.audio_token_index)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.audio_token_index)
         return inputs_embeds
 
     def forward(self,
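Note that the deleted branch was reachable only when envs.VLLM_USE_V1 was false. With the V0 engine removed that condition can never hold, so the index-map merge was dead code and the placeholder-token merge_multimodal_embeddings path is all that remains.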
vllm/model_executor/models/utils.py

@@ -15,7 +15,7 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
+from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)
@@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
         _embedding_count_expression(inner) for inner in embeddings)
 
 
-def merge_multimodal_embeddings_from_map(
-        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
-        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
-    """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
-    placeholder map .
-
-    Note:
-        This updates ``inputs_embeds`` in place.
-    """
-    flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
-    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
-        placeholder_map.src].to(dtype=inputs_embeds.dtype)
-    return inputs_embeds
-
-
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
     is_multimodal: torch.Tensor,
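For reference, the deleted helper reduced to an index-based scatter: V0's scheduler precomputed where each embedding row belonged, so no scan of input_ids for placeholder tokens was needed. A self-contained sketch of those semantics, with IndexMap as a hypothetical stand-in for MultiModalPlaceholderMap.IndexMap and mm_embeds assumed pre-flattened (the deleted code called _flatten_embeddings first):

import torch

class IndexMap:
    # Hypothetical stand-in: the V0 map carried explicit index tensors.
    def __init__(self, src: torch.Tensor, dest: torch.Tensor):
        self.src = src    # rows to read from the flattened mm embeddings
        self.dest = dest  # positions to write within inputs_embeds

def merge_from_map_sketch(inputs_embeds: torch.Tensor,
                          mm_embeds: torch.Tensor,
                          placeholder_map: IndexMap) -> torch.Tensor:
    # In-place scatter mirroring the deleted helper, including its cast to
    # the destination dtype.
    inputs_embeds[placeholder_map.dest] = mm_embeds[placeholder_map.src].to(
        dtype=inputs_embeds.dtype)
    return inputs_embeds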