[Bugfix][VLM] Fix mixed-modality inference backward compatibility for V0 (#12313)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2025-01-22 05:06:36 -08:00 committed by GitHub
parent 528dbcac7d
commit 16366ee8bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 91 additions and 27 deletions

View File

@ -816,7 +816,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return image_feature return image_feature
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return None return None
@ -842,8 +842,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors, multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
str]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
@ -852,6 +851,34 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
[self.config.image_token_index, self.config.video_token_index]) [self.config.image_token_index, self.config.video_token_index])
return inputs_embeds return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[NestedTensors] = None,
video_input: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_index,
)
if video_input is not None:
video_embeds = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_index,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -871,13 +898,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner from
# condition is for v0 compatibility. # `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, video_input = self._parse_and_validate_video_input(**kwargs)
multimodal_embeddings)
input_ids = None if image_input is None and video_input is None:
inputs_embeds = None
else:
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input)
input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,

View File

@ -55,7 +55,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem) VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems, from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser) MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
@ -1233,7 +1233,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return modalities return modalities
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
@ -1260,8 +1260,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors, multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
str]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
@ -1270,6 +1269,33 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
[self.config.image_token_id, self.config.video_token_id]) [self.config.image_token_id, self.config.video_token_id])
return inputs_embeds return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[tuple[torch.Tensor, ...]] = None,
video_input: Optional[tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -1303,22 +1329,25 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner from
# condition is for v0 compatibility. # `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
# We need to check for usage of mrope here in case there is if image_input is None and video_input is None:
# multimodal data. inputs_embeds = None
# TODO (ywang96): move this to model runner in V1. else:
if multimodal_embeddings is not None and uses_mrope(self.config): if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}") f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
inputs_embeds = self.get_input_embeddings(input_ids, input_ids,
multimodal_embeddings) image_input=image_input,
input_ids = None video_input=video_input)
input_ids = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,