[V0 Deprecation][Models] Remove all V0 condition for mm embeddings merge (#25331)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 4079a63a86
commit b765adccd7
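In V1, the model runner builds `inputs_embeds` by calling `get_multimodal_embeddings` and `get_input_embeddings` before invoking the model, so the per-model fallback that repeated this work inside `forward()` is dead code; this commit deletes that fallback everywhere. A minimal sketch of the pattern being removed (schematic `MyModelSketch`, not any specific class in the diff):

```python
import torch
from typing import Optional


class MyModelSketch:
    # Before this commit: every multimodal model carried a V0-only branch that
    # merged vision embeddings inside forward() when the runner had not
    # already supplied inputs_embeds.
    def forward_with_v0_fallback(self, input_ids, positions,
                                 intermediate_tensors=None,
                                 inputs_embeds: Optional[torch.Tensor] = None,
                                 **kwargs):
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:  # V0 compatibility path (removed below)
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                vision_embeddings,
                is_multimodal=input_ids == self.config.image_token_index,
            )
            input_ids = None
        return self.language_model.model(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)

    # After: the V1 runner always passes inputs_embeds, so only the
    # pipeline-parallel guard remains.
    def forward_v1_only(self, input_ids, positions,
                        intermediate_tensors=None,
                        inputs_embeds: Optional[torch.Tensor] = None,
                        **kwargs):
        if intermediate_tensors is not None:
            inputs_embeds = None
        return self.language_model.model(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
```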
@@ -427,17 +427,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,
@@ -672,17 +672,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == _IMAGE_TOKEN_ID,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -1014,18 +1014,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            image_token_id = self.model.vocabulary_mapping.image_token_id
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == image_token_id,
-            )
-            input_ids = None

         hidden_states = self.model(input_ids,
                                    positions,
                                    intermediate_tensors,
@@ -440,17 +440,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_id,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,
@@ -614,17 +614,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.image_token_id,
-            )
-            input_ids = None

         hidden_states = self.language_model(input_ids,
                                             positions,
                                             intermediate_tensors,
@@ -352,17 +352,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == _IMAGE_TOKEN_ID,
-            )
-            input_ids = None

         hidden_states = self.language_model(
             input_ids=input_ids,
             positions=positions,
@@ -596,25 +596,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            if (vision_embeddings is not None) and len(vision_embeddings) != 0:
-                kwargs = self.prepare_attn_masks(
-                    input_ids,
-                    positions,
-                    mask_dtype=self.dtype,
-                    **kwargs,
-                )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -71,7 +71,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.config import uses_mrope
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from ..layers.activation import SiluAndMul
@@ -80,8 +79,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
 from .qwen2_vl import (_create_qwen2vl_field_factory,
                        apply_rotary_pos_emb_vision)
 from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)
 from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)
@@ -1552,32 +1550,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             multimodal_embeddings += video_embeddings
         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[Glm4vImageInputs] = None,
-        video_input: Optional[Glm4vVideoInputs] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.get_input_embeddings(input_ids)
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=self.config.image_token_id,
-            )
-
-        if video_input is not None:
-            video_embeds = self._process_video_input(video_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                video_embeds,
-                placeholder_token_id=self.config.video_token_id,
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1604,26 +1576,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            video_input = self._parse_and_validate_video_input(**kwargs)
-
-            if image_input is None and video_input is None:
-                inputs_embeds = None
-            else:
-                if uses_mrope(self.config):
-                    assert positions.ndim == 2 and positions.size(0) == 3, (
-                        "multimodal section rotary embedding requires "
-                        f"(3, seq_len) positions, but got {positions.size()}")
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    video_input=video_input)
-            input_ids = None

         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,
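The deleted `get_input_embeddings_v0` helpers (here and in the files below) were thin wrappers around `merge_multimodal_embeddings`. Roughly — a simplified sketch, not the vLLM implementation — that merge scatters the multimodal embeddings into the text embeddings wherever the placeholder token appears:

```python
import torch


def merge_multimodal_embeddings_sketch(
        input_ids: torch.Tensor,       # [num_tokens]
        inputs_embeds: torch.Tensor,   # [num_tokens, hidden_size]
        mm_embeds: torch.Tensor,       # [num_mm_tokens, hidden_size]
        placeholder_token_id: int) -> torch.Tensor:
    # Boolean mask of the positions reserved for this modality.
    is_multimodal = input_ids == placeholder_token_id
    assert int(is_multimodal.sum()) == mm_embeds.shape[0], (
        "placeholder count must match the number of multimodal embeddings")
    # Overwrite the placeholder rows with the projected multimodal features.
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[is_multimodal] = mm_embeds.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
```

The same boolean mask is what the remaining code passes as `is_multimodal=` to `get_input_embeddings`, with the V1 runner driving the call.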
@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn, isin_list
+from .utils import flatten_bn


 class GLMVImagePixelInputs(TensorSchema):
@@ -618,21 +618,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=isin_list(input_ids, [
-                    self.config.boi_token_id,
-                    self.config.pad_token_id,
-                    self.config.eoi_token_id,
-                ]),
-            )
-            input_ids = None

         hidden_states = self.transformer(input_ids, positions,
                                          intermediate_tensors, inputs_embeds)

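GLM-4V marks three different placeholder ids (boi/pad/eoi) as multimodal, so a single equality test is not enough; the now-unimported `isin_list` helper builds a membership mask instead. A rough equivalent using `torch.isin` (assumed semantics, not the vLLM source):

```python
import torch


def isin_list_sketch(input_ids: torch.Tensor,
                     token_ids: list[int]) -> torch.Tensor:
    # True wherever input_ids equals any id in token_ids.
    return torch.isin(input_ids,
                      torch.tensor(token_ids, device=input_ids.device))

# Example use, mirroring the removed call:
#   mask = isin_list_sketch(input_ids, [config.boi_token_id,
#                                       config.pad_token_id,
#                                       config.eoi_token_id])
```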
@@ -765,17 +765,6 @@ class GraniteSpeechForConditionalGeneration(
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            audio_embeds = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                audio_embeds,
-                is_multimodal=input_ids == self.config.audio_token_index,
-            )
-            input_ids = None

         model_output = self.language_model(input_ids, positions,
                                            intermediate_tensors, inputs_embeds)
         return model_output
@@ -45,8 +45,7 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import get_vision_encoder_info

 EOT = "<|endofturn|>"
@@ -747,18 +746,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None

-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=isin_list(
-                    input_ids,
-                    [self.config.image_token_id, self.config.video_token_id]),
-            )
-            input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -702,17 +702,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_id,
-            )
-            input_ids = None

         hidden_states = self.model.text_model(input_ids,
                                               positions,
                                               intermediate_tensors,
@@ -40,7 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, isin_list, maybe_prefix)
+                    init_vllm_registered_model, maybe_prefix)


 class InternS1MultiModalProjector(nn.Module):
@@ -798,22 +798,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             input_ids = None
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            context_token_ids = [
-                token_id for token_id in (self.img_context_token_id,
-                                          self.video_context_token_id)
-                if token_id is not None
-            ]
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=isin_list(input_ids, context_token_ids),
-            )
-            input_ids = None

         forward_kwargs = {
             "input_ids": input_ids,
             "positions": positions,
@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    isin_list, maybe_prefix)
+                    maybe_prefix)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -1371,22 +1371,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             input_ids = None
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            context_token_ids = [
-                token_id for token_id in (self.img_context_token_id,
-                                          self.video_context_token_id)
-                if token_id is not None
-            ]
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=isin_list(input_ids, context_token_ids),
-            )
-            input_ids = None

         forward_kwargs = {
             "input_ids": input_ids,
             "positions": positions,
@@ -433,22 +433,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            if image_input is None:
-                inputs_embeds = None
-            else:
-                image_embeds = self._process_image_input(image_input)
-                inputs_embeds = self.get_input_embeddings(
-                    input_ids,
-                    image_embeds,
-                    is_multimodal=input_ids ==
-                    self.config.media_placeholder_token_id,
-                )
-            input_ids = None

         hidden_states = self.language_model(
             input_ids=input_ids,
@@ -723,17 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -547,17 +547,6 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -431,17 +431,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.video_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -30,8 +30,7 @@ from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                          LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)

 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -850,33 +849,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,

         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[LlavaOnevisionImagePixelInputs] = None,
-        video_input: Optional[LlavaOnevisionVideoPixelInputs] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.get_input_embeddings(input_ids)
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=self.config.image_token_index,
-            )
-
-        if video_input is not None:
-            video_embeds = self._process_video_pixels(video_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                video_embeds,
-                placeholder_token_id=self.config.video_token_index,
-            )
-
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -894,22 +866,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            video_input = self._parse_and_validate_video_input(**kwargs)
-
-            if image_input is None and video_input is None:
-                inputs_embeds = None
-            else:
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    video_input=video_input)
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -71,7 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix
+from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix

 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -1154,19 +1154,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=isin_list(input_ids, list(self.mm_token_ids)),
-            )
-            input_ids = None

         hidden_states = self.llm.model(
             input_ids=input_ids,
             positions=positions,
@@ -571,17 +571,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -823,17 +823,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner,
-        # this condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         return self.language_model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)

@@ -1490,17 +1490,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.img_patch_id,
-            )
-            input_ids = None

         hidden_states = self.model(input_ids,
                                    positions,
                                    intermediate_tensors,
@@ -35,7 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
 from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (flatten_bn,
                                               init_vllm_registered_model,
-                                              isin_list, maybe_prefix)
+                                              maybe_prefix)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, MultiModalKwargsItems,
@@ -1135,22 +1135,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
             input_ids = None
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            context_token_ids = [
-                token_id for token_id in (self.img_context_token_id,
-                                          self.video_context_token_id)
-                if token_id is not None
-            ]
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=isin_list(input_ids, context_token_ids),
-            )
-            input_ids = None

         hidden_states = self.language_model(
             input_ids=input_ids,
             positions=positions,
@@ -608,17 +608,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             input_ids = None
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.img_context_token_id,
-            )
-            input_ids = None

         forward_kwargs = {
             "input_ids": input_ids,
             "positions": positions,
@@ -511,17 +511,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.image_pad_token_id,
-            )
-            input_ids = None

         # up until here we have an inputs_embeds 100% numerical identity
         # between the OG HF Transformers implementation and ours
         hidden_states = self.llm(
@@ -596,18 +596,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.image_pad_token_id,
-            )
-            input_ids = None

         # up until here we have a inputs_embeds 100% numerical identity
         # between the OG HF Transformers implementation and ours
         hidden_states = self.llm(
@@ -370,17 +370,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -679,17 +679,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=self.image_token_id,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -1411,22 +1411,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            audio_input = self._parse_and_validate_audio_input(**kwargs)
-
-            if image_input is None and audio_input is None:
-                inputs_embeds = None
-            else:
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    audio_input=audio_input)
-            input_ids = None

         hidden_states = self.language_model(
             input_ids,
             positions,
@@ -35,8 +35,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
-                    merge_multimodal_embeddings)
+from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix

 # <|endoftext10|> (see vocab.json in hf model)
 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -1174,35 +1173,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):

         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[Phi4MMImagePixelInputs] = None,
-        audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
-    ) -> torch.Tensor:
-        audio_projection_mode = 'speech'
-        inputs_embeds = self.get_input_embeddings(input_ids)
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
-            )
-            audio_projection_mode = 'vision'
-
-        if audio_input is not None:
-            audio_embeds = self._process_audio_input(
-                audio_input, audio_projection_mode=audio_projection_mode)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                audio_embeds,
-                placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1214,22 +1184,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            audio_input = self._parse_and_validate_audio_input(**kwargs)
-
-            if image_input is None and audio_input is None:
-                inputs_embeds = None
-            else:
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    audio_input=audio_input)
-            input_ids = None

         hidden_states = self.model(
             input_ids,
             positions,
@@ -444,17 +444,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.vision_args.image_token_id,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -69,8 +69,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)

 try:
     import flash_attn
@@ -908,26 +907,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
                 multimodal_embeddings.append((video_embeds, "video"))
         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        for embeddings, modality in multimodal_embeddings:
-            if modality == "audio":
-                placeholder_token_id = self.config.audio_token_index
-            if modality == "image":
-                placeholder_token_id = self.config.image_token_index
-            if modality == "video":
-                placeholder_token_id = self.config.video_token_index
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, embeddings, placeholder_token_id)
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -939,14 +918,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs)
-            inputs_embeds = self.get_input_embeddings_v0(
-                input_ids, multimodal_embeddings)
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -64,7 +64,6 @@ from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

@@ -75,8 +74,7 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)
 from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)
@@ -1365,40 +1363,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             multimodal_embeddings += video_embeddings
         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[Qwen2_5_VLImageInputs] = None,
-        video_input: Optional[Qwen2_5_VLVideoInputs] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.get_input_embeddings(input_ids)
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            if self.is_multimodal_pruning_enabled:
-                image_embeds = self._postprocess_image_embeds_evs(
-                    image_embeds, image_input
-                )
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=self.config.image_token_id,
-            )
-
-        if video_input is not None:
-            video_embeds = self._process_video_input(video_input)
-            if self.is_multimodal_pruning_enabled:
-                video_embeds = self._postprocess_video_embeds_evs(
-                    video_embeds, video_input
-                )
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                video_embeds,
-                placeholder_token_id=self.config.video_token_id,
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1421,26 +1385,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            video_input = self._parse_and_validate_video_input(**kwargs)
-
-            if image_input is None and video_input is None:
-                inputs_embeds = None
-            else:
-                if uses_mrope(self.config):
-                    assert positions.ndim == 2 and positions.size(0) == 3, (
-                        "multimodal section rotary embedding requires "
-                        f"(3, seq_len) positions, but got {positions.size()}")
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    video_input=video_input)
-            input_ids = None

         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,
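The removed V0 branches in the Qwen2/2.5-VL and GLM-4.1V models also carried the M-RoPE sanity check: with multimodal rotary embeddings, positions form a (3, seq_len) tensor holding temporal, height, and width indices instead of a flat (seq_len,) vector. A standalone restatement of that assertion:

```python
import torch


def check_mrope_positions(positions: torch.Tensor) -> None:
    # Same condition the removed branches asserted before merging embeddings.
    assert positions.ndim == 2 and positions.size(0) == 3, (
        "multimodal section rotary embedding requires "
        f"(3, seq_len) positions, but got {tuple(positions.shape)}")


check_mrope_positions(torch.zeros(3, 128, dtype=torch.long))   # passes
# check_mrope_positions(torch.zeros(128, dtype=torch.long))    # would fail
```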
@@ -449,17 +449,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_index,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
@@ -65,15 +65,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)
 from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)
@@ -1464,32 +1462,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

         return multimodal_embeddings

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[Qwen2VLImagePixelInputs] = None,
-        video_input: Optional[Qwen2VLVideoPixelInputs] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.get_input_embeddings(input_ids)
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=self.config.image_token_id,
-            )
-
-        if video_input is not None:
-            video_embeds = self._process_video_input(video_input)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                video_embeds,
-                placeholder_token_id=self.config.video_token_id,
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1515,26 +1487,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            video_input = self._parse_and_validate_video_input(**kwargs)
-
-            if image_input is None and video_input is None:
-                inputs_embeds = None
-            else:
-                if uses_mrope(self.config):
-                    assert positions.ndim == 2 and positions.size(0) == 3, (
-                        "multimodal section rotary embedding requires "
-                        f"(3, seq_len) positions, but got {positions.size()}")
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    video_input=video_input)
-            input_ids = None

         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,
@@ -68,7 +68,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -82,8 +81,7 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention,
 from .qwen2_vl import Qwen2VLProcessingInfo
 from .qwen3 import Qwen3ForCausalLM, Qwen3Model
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    _merge_multimodal_embeddings, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    _merge_multimodal_embeddings, maybe_prefix)
 from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)
@@ -1464,75 +1462,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,

         return inputs_embeds

-    def get_input_embeddings_v0(
-        self,
-        input_ids: torch.Tensor,
-        image_input: Optional[Qwen2_5_VLImageInputs] = None,
-        video_input: Optional[Qwen2_5_VLVideoInputs] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.get_input_embeddings(input_ids)
-
-        if self.use_deepstack:
-            visual_dim = inputs_embeds.shape[-1]
-            deepstack_input_embeds = None
-            if image_input is not None or video_input is not None:
-                deepstack_input_embeds = torch.zeros_like(
-                    inputs_embeds).unsqueeze(1).repeat(
-                        1, self.deepstack_num_level, 1).flatten(1)
-
-        if image_input is not None:
-            image_embeds = self._process_image_input(image_input)
-            if self.use_deepstack:
-                image_embeds = torch.cat(image_embeds)
-
-                image_embeds, image_embeds_multiscale = image_embeds.split(
-                    [visual_dim, visual_dim * self.deepstack_num_level],
-                    dim=-1)
-
-                deepstack_input_embeds = merge_multimodal_embeddings(
-                    input_ids,
-                    deepstack_input_embeds,
-                    image_embeds_multiscale,
-                    placeholder_token_id=self.config.image_token_id,
-                )
-
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                image_embeds,
-                placeholder_token_id=self.config.image_token_id,
-            )
-
-        if video_input is not None:
-            video_embeds = self._process_video_input(video_input)
-            if self.use_deepstack:
-                video_embeds = torch.cat(video_embeds)
-
-                video_embeds, video_embeds_multiscale = video_embeds.split(
-                    [visual_dim, visual_dim * self.deepstack_num_level],
-                    dim=-1)
-
-                deepstack_input_embeds = merge_multimodal_embeddings(
-                    input_ids,
-                    deepstack_input_embeds,
-                    video_embeds_multiscale,
-                    placeholder_token_id=self.config.video_token_id,
-                )
-
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids,
-                inputs_embeds,
-                video_embeds,
-                placeholder_token_id=self.config.video_token_id,
-            )
-
-        if self.use_deepstack and deepstack_input_embeds is not None:
-            deepstack_input_embeds = deepstack_input_embeds.view(
-                inputs_embeds.shape[0], self.deepstack_num_level,
-                visual_dim).permute(1, 0, 2).contiguous()
-            self._set_deepstack_input_embeds(deepstack_input_embeds)
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1568,26 +1497,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        elif inputs_embeds is None:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            video_input = self._parse_and_validate_video_input(**kwargs)
-
-            if image_input is None and video_input is None:
-                inputs_embeds = None
-            else:
-                if uses_mrope(self.config):
-                    assert positions.ndim == 2 and positions.size(0) == 3, (
-                        "multimodal section rotary embedding requires "
-                        f"(3, seq_len) positions, but got {positions.size()}")
-                inputs_embeds = self.get_input_embeddings_v0(
-                    input_ids,
-                    image_input=image_input,
-                    video_input=video_input)
-            input_ids = None

         if self.use_deepstack and inputs_embeds is not None and get_pp_group(
         ).is_first_rank:
             deepstack_input_embeds = self._get_deepstack_input_embeds(
@@ -767,18 +767,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids ==
-                self.transformer.visual.image_pad_id,
-            )
-            input_ids = None

         hidden_states = self.transformer(input_ids, positions,
                                          intermediate_tensors, inputs_embeds)
         return hidden_states
@@ -874,17 +874,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             input_ids = None
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.img_context_token_id,
-            )
-            input_ids = None

         forward_kwargs = {
             "input_ids": input_ids,
             "positions": positions,
@@ -881,19 +881,6 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        # NOTE: In v1, inputs_embeds is always generated at model runner from
-        # `get_multimodal_embeddings` and `get_input_embeddings`, this
-        # condition is only for v0 compatibility.
-        if inputs_embeds is None:
-            multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
-            if multimodal_embeds is not None:
-                inputs_embeds = self.get_input_embeddings(
-                    input_ids,
-                    multimodal_embeds,
-                    is_multimodal=input_ids == self.config.image_token_id,
-                )
-                input_ids = None
-
         model_output = super().forward(input_ids, positions,
                                        intermediate_tensors, inputs_embeds)
         return model_output
@@ -597,18 +597,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_index,
-            )
-            input_ids = None

         language_model = self.language_model
         if hasattr(language_model, "language_model"):
             language_model = language_model.language_model
@@ -371,19 +371,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
-
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
-        elif inputs_embeds is None:
-            audio_encoder = self.tokenizer.instruct.audio_encoder
-            audio_tok_id = audio_encoder.audio_token
-            audio_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                audio_embeddings,
-                is_multimodal=input_ids == audio_tok_id,
-            )
-            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   intermediate_tensors,
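For context, the NOTE comments removed above all point at the same V1 contract: the model runner, not the model, resolves multimodal embeddings and hands `forward()` a ready-made `inputs_embeds`. A very rough sketch of that call order (hypothetical helper with simplified signatures, not the actual model runner code):

```python
def prepare_multimodal_inputs_sketch(model, input_ids, mm_kwargs):
    # 1. Encode the raw multimodal inputs (images / video / audio) once.
    mm_embeds = model.get_multimodal_embeddings(**mm_kwargs)
    # 2. Merge them into the token embeddings at the placeholder positions.
    inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
    # 3. forward() then runs purely on inputs_embeds; input_ids goes unused,
    #    which is why the per-model fallback above could be deleted.
    return inputs_embeds
```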