[Model] Use merge_by_field_config for MM models (A-C) (#26073)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Cyrus Leung 2025-10-02 23:17:31 +08:00 committed by yewentao256
parent 0655b90d80
commit 00c0b25e82
5 changed files with 29 additions and 24 deletions
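Reviewer context: the pattern removed across these files relied on the flatten_bn helper in vllm/model_executor/models/utils.py to collapse the batch and per-item dimensions of multimodal kwargs inside each model's _parse_and_validate_image_input. Setting merge_by_field_config = True asks the multimodal input pipeline to perform that merge itself, driven by each field's config, before the tensors reach the model. A minimal, self-contained sketch of the idea follows; the helper below is a stand-in rather than vLLM's implementation, and the shapes are illustrative only.

    import torch


    def flatten_bn_sketch(x, concat: bool = False):
        # Rough stand-in for the flatten_bn helper being removed here:
        # collapse the batch ("b") and per-prompt item ("n") dimensions.
        if isinstance(x, (list, tuple)):
            return torch.cat(list(x), dim=0) if concat else list(x)
        return x.flatten(0, 1)


    # Before: each model flattened its own kwargs, e.g.
    #     pixel_values = flatten_bn(pixel_values, concat=True)
    # With merge_by_field_config = True, pixel_values is expected to arrive
    # already merged, so the parse method forwards it unchanged.
    batch = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]
    print(flatten_bn_sketch(batch, concat=True).shape)  # torch.Size([3, 3, 336, 336])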

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Optional, Union
+from typing import Annotated, Literal, Optional, Union
 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    is_pp_missing_parameter, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
+                    maybe_prefix)
 class AriaImagePixelInputs(TensorSchema):
@@ -52,6 +52,8 @@ class AriaImagePixelInputs(TensorSchema):
         - w: Width of each image
     """
+    type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
         TensorShape("bn", 3, "h", "w"),
@@ -485,6 +487,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     This model combines a vision tower, a multi-modal projector, and a language
     model to perform tasks that involve both image and text inputs.
     """
+    merge_by_field_config = True
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -551,12 +555,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
         return AriaImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            pixel_mask=flatten_bn(pixel_mask, concat=True),
+            type="pixel_values",
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
         )
     def _create_patch_attention_mask(
-            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        self,
+        pixel_mask: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         if pixel_mask is None:
             return None

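The type: Literal["pixel_values"] field added to AriaImagePixelInputs follows the tagged-input convention already visible in the constructors throughout this commit (type="pixel_values" vs. type="image_embeds"): the literal tag, not the set of keys that happen to be present, tells the forward path which kind of image input it received. A self-contained sketch of the same pattern, with plain dataclasses and hypothetical class names standing in for vLLM's TensorSchema:

    from dataclasses import dataclass
    from typing import Literal, Union

    import torch


    @dataclass
    class PixelInputs:
        type: Literal["pixel_values"]
        pixel_values: torch.Tensor      # (bn, 3, h, w)


    @dataclass
    class EmbeddingInputs:
        type: Literal["image_embeds"]
        data: torch.Tensor              # (bn, num_features, hidden_size)


    ImageInputs = Union[PixelInputs, EmbeddingInputs]


    def image_features(inputs: ImageInputs) -> torch.Tensor:
        # Dispatch on the literal tag rather than probing for attributes.
        if inputs.type == "pixel_values":
            return inputs.pixel_values
        return inputs.data


    print(image_features(PixelInputs("pixel_values", torch.zeros(1, 3, 8, 8))).shape)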
View File

@@ -31,7 +31,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
@@ -295,6 +295,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
                                         dummy_inputs=AyaVisionDummyInputsBuilder)
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
+    merge_by_field_config = True
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -379,8 +380,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return AyaVisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,

View File

@@ -26,12 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .blip import BlipVisionModel
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
-
-# We use this internally as placeholders since there is no image token
-# defined on the HuggingFace repo
-_IMAGE_TOKEN_ID = 50265
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 class Blip2ImagePixelInputs(TensorSchema):
@@ -514,6 +509,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
                                         dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsQuant):
+    merge_by_field_config = True
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -570,8 +566,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if pixel_values is not None:
             expected_h = expected_w = self.config.vision_config.image_size
             return Blip2ImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w
@@ -580,7 +575,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if image_embeds is not None:
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
         raise AssertionError("This line should be unreachable.")

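The resolve_bindings={"h": ..., "w": ...} arguments left untouched by these hunks bind the symbolic h/w dimensions declared in each schema's TensorShape to concrete sizes from the model config, so the schema can still validate the tensors it now receives without manual flattening. A toy sketch of that kind of check; this is not vLLM's TensorSchema API, and 224 is only an illustrative image size:

    import torch


    def check_shape(t: torch.Tensor, spec, bindings):
        # Each symbolic dim ("bn", "h", "w") must match its binding if one is
        # given, or stay self-consistent otherwise; literal ints match exactly.
        assert t.ndim == len(spec), f"expected {len(spec)} dims, got {t.ndim}"
        seen = dict(bindings)
        for size, dim in zip(t.shape, spec):
            if isinstance(dim, int):
                assert size == dim, f"expected {dim}, got {size}"
            else:
                assert seen.setdefault(dim, size) == size, f"{dim}: {size} != {seen[dim]}"


    # e.g. a pixel_values tensor declared as ("bn", 3, "h", "w") with h/w bound
    # to the vision tower's image size.
    check_shape(torch.randn(4, 3, 224, 224), ("bn", 3, "h", "w"), {"h": 224, "w": 224})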
View File

@@ -42,7 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (flatten_bn, is_pp_missing_parameter,
+from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -935,6 +935,8 @@ class ChameleonModel(nn.Module):
                                         dummy_inputs=ChameleonDummyInputsBuilder)
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP, SupportsQuant):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
@@ -981,8 +983,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         expected_h = expected_w = vq_config.resolution
         return ChameleonImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w

View File

@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
@@ -317,6 +317,7 @@ class Cohere2VisionMultiModalProcessor(
                                         dummy_inputs=Cohere2VisionDummyInputsBuilder)
 class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
+    merge_by_field_config = True
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -399,8 +400,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return Cohere2VisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,