[V1] Refactor model executable interface for multimodal models (#10570)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2024-11-26 12:46:11 -08:00 committed by GitHub
parent 7576cd38df
commit 2f0a0a17a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 568 additions and 293 deletions

View File

@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
@ -609,6 +610,25 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_projection(query_output)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
BLIP2_IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -616,6 +636,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
@ -648,32 +669,24 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
See also:
:class:`Blip2ImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

View File

@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
@ -38,7 +39,7 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
@ -987,6 +988,29 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype))
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -994,27 +1018,27 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(

View File

@ -33,7 +33,8 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
@ -545,6 +546,30 @@ class ChatGLMModel(nn.Module):
""")
return GLMImagePixelInputs(pixel_values=pixel_values)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is None:
return None
pixel_values = image_input["pixel_values"].to(
dtype=self.config.torch_dtype)
vision_embeddings = self.vision(pixel_values)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=multimodal_embeddings,
boi_token_id=self.config.boi_token_id,
eoi_token_id=self.config.eoi_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -552,26 +577,17 @@ class ChatGLMModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is None:
inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is not None:
pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)
boi_token_id = self.config.boi_token_id
eoi_token_id = self.config.eoi_token_id
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=image_embeds,
boi_token_id=boi_token_id,
eoi_token_id=eoi_token_id)
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if intermediate_tensors is None and inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
else:
inputs_embeds = intermediate_tensors["hidden_states"]

View File

@ -35,6 +35,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@ -302,6 +303,25 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -309,24 +329,19 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model(
input_ids=input_ids,

View File

@ -2,7 +2,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
@ -10,10 +10,14 @@ from vllm.utils import supports_kw
from .interfaces_base import is_embedding_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.multimodal.inputs import NestedTensors # noqa: F401
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
T = TypeVar("T", default="NestedTensors")
@runtime_checkable
class SupportsMultiModal(Protocol):
@ -28,6 +32,36 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
"""
...
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
) -> torch.Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
"""
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead

View File

@ -26,6 +26,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -641,6 +642,26 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
visual_token_mask = None
return visual_token_mask
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert self.img_context_token_id is not None
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.img_context_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -648,26 +669,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
visual_token_mask = None
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
visual_token_mask = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
visual_token_mask = self._get_visual_token_mask(input_ids)
input_ids = None
else:
inputs_embeds = None
visual_token_mask = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
forward_kwargs = {
"input_ids": input_ids,
@ -677,6 +694,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.img_context_token_id is not None:
visual_token_mask = self._get_visual_token_mask(input_ids)
# We always overwrite it back to None after computing visual token
# mask so that this doesn't need to depend on encoder output
self.img_context_token_id = None
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})

View File

@ -478,7 +478,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def process_mm_inputs(self, **kwargs):
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@ -488,12 +488,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
@ -544,10 +544,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
"""
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

View File

@ -19,6 +19,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
@ -565,6 +566,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
multimodal_embeddings,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -572,6 +597,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT.
@ -620,24 +646,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
"""
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
@ -645,7 +661,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
@ -388,6 +389,25 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:
return None
vision_embeddings = self._process_video_pixels(video_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -395,6 +415,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT-Video.
@ -404,22 +425,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values_videos: Pixels in each frames for each input videos.
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@ -21,6 +21,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
@ -824,6 +825,49 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings.append((vision_embeddings, "image"))
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings.append((video_embeddings, "video"))
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
str]]] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.image_token_index)
if modality == "video":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.video_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -831,6 +875,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-Onevision.
@ -840,28 +885,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values_videos: Pixels in each frames for each input videos.
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@ -3,7 +3,7 @@ import re
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict
import torch
from einops import rearrange
@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@ -756,6 +757,12 @@ class MolmoModel(nn.Module):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@ -1098,19 +1105,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return image_features
def _merge_multimodal_embeddings(
self,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
image_input_idx: torch.Tensor,
seq_len: Union[torch.Tensor, List[torch.Tensor]],
) -> torch.Tensor:
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
image_input_idx = image_input["image_input_idx"]
seq_len = image_input["seq_len"]
batch_size, num_image, num_patch = image_features.shape[:3]
assert image_input_idx.shape == (batch_size, num_image, num_patch)
image_features = image_features.to(inputs_embeds.device)
seq_len = seq_len.to(inputs_embeds.device)
# insert the image feature into the embedding.
image_features = image_features.view(batch_size, num_image * num_patch,
-1)
@ -1130,12 +1134,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
image_input_idx = image_input_idx.flatten()[:, None]
mat = image_input_idx == torch.arange(
seq_len.sum().item(), device=inputs_embeds.device)[None, :]
seq_len.sum().item(), device=image_features.device)[None, :]
mat = mat.to(image_features.dtype)
inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
image_features, mat)
# Note: In this original implementation from AI2, the final
# vision_embeddings will be always be the same length
# of input embedddings, which is not very efficient.
# TODO(ywang96): see if this can be optimized.
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = inputs_embeds + multimodal_embeddings
return inputs_embeds
def forward(
@ -1145,39 +1161,27 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> SamplerOutput:
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.model.embed_tokens(input_ids)
image_features = self._process_image_input(image_input)
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
inputs_embeds = self._merge_multimodal_embeddings(
inputs_embeds,
image_features,
image_input["image_input_idx"],
image_input["seq_len"],
)
else:
inputs_embeds = self.model.embed_tokens(input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

View File

@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
@ -240,36 +241,45 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None:
vision_embeddings = self._process_image_input(
parsed_image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (
self.config.hidden_size**-0.5)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@ -676,7 +676,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds
def process_mm_inputs(self, **kwargs):
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@ -686,12 +686,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if vision_embeddings is not None:
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
@ -703,12 +703,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object):
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

View File

@ -42,10 +42,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
@ -371,6 +373,25 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return masked_audio_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -378,33 +399,27 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
inputs_embeds = None
else:
inputs_embeds = self.language_model.embed_tokens(input_ids)
masked_audio_features = self._process_audio_input(audio_input)
# merge llm embeddings and audio features
mask = (input_ids == self.config.audio_token_index)
inputs_embeds[mask, :] = masked_audio_features
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None
input_ids = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,

View File

@ -63,7 +63,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs)
MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
@ -1238,6 +1238,55 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds[mask, :] = multimodal_embeddings
return inputs_embeds
def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
return None
# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
if image_input is not None:
image_embeds = self._process_image_input(image_input)
multimodal_embeddings.append((image_embeds, "image"))
if video_input is not None:
video_embeds = self._process_video_input(video_input)
multimodal_embeddings.append((video_embeds, "video"))
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
str]]] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
embeddings,
placeholder_token_id=self.config.image_token_id,
)
if modality == "video":
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
embeddings,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@ -1245,6 +1294,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL.
@ -1266,42 +1316,26 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.model.embed_tokens(input_ids)
# We need to check for usage of mrope here in case there is
# multimodal data.
# TODO (ywang96): move this to model runner in V1.
if multimodal_embeddings is not None and uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
input_ids = None
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None
hidden_states = self.model(
input_ids=input_ids,

View File

@ -449,10 +449,36 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
return result
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None
audio_embeddings = self._process_audio_input(audio_input)
return audio_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# TODO(ywang96): use merge_multimodal_embeddings after
# v0 is deprecated
merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor],
intermediate_tensors: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Ultravox
@ -466,30 +492,28 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
Args:
audio_features: A batch of audio inputs [B, N, 80, M].
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
merge_multimodal_embeddings_from_map(
inputs_embeds, audio_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
input_ids = None
else:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
# TODO(ywang96): remove attn_metadata from get_input_embeddings
# after v0 is deprecated
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings,
attn_metadata)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,

View File

@ -356,8 +356,7 @@ def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
List[torch.Tensor]]],
multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]],
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
@ -374,8 +373,6 @@ def embed_multimodal(
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,

View File

@ -363,7 +363,8 @@ class GPUModelRunner:
# 2. A list (length: num_images) of tensors, each of shape
# [feature_size, hidden_size] in case when the feature size is
# dynamic depending on input images.
encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs)
encoder_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):