Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 03:45:02 +08:00)
[Model] Use merge_by_field_config for MM models (A-C) (#26073)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 0655b90d80
commit 00c0b25e82
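The change is the same across the five files touched here (aria, aya_vision, blip2, chameleon, cohere2_vision): set merge_by_field_config = True on the model class and drop the manual flatten_bn(..., concat=True) calls when building the TensorSchema input objects, since the multimodal kwargs now arrive already merged. The sketch below is a minimal, self-contained illustration of that idea, not vLLM code; flatten_bn_sketch is a hypothetical stand-in for the removed helper.

from typing import Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Hypothetical stand-in for the removed flatten_bn helper: merge the
    per-request batch structure of multimodal tensors."""
    if isinstance(x, torch.Tensor):
        # (batch, num_items, ...) -> (batch * num_items, ...)
        return x.flatten(0, 1)
    if concat:
        # list of (num_items_i, ...) -> (sum(num_items_i), ...)
        return torch.cat(x, dim=0)
    return x


# Before this commit: each model merged the per-request image tensors itself.
per_request = [torch.rand(2, 3, 224, 224), torch.rand(1, 3, 224, 224)]
merged_manually = flatten_bn_sketch(per_request, concat=True)
assert merged_manually.shape == (3, 3, 224, 224)

# After: with merge_by_field_config = True the model is assumed to receive a
# single (total_images, 3, H, W) tensor and passes it straight into its
# TensorSchema-validated inputs class.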
vllm/model_executor/models/aria.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Optional, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    is_pp_missing_parameter, maybe_prefix)
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
+                    maybe_prefix)
 
 
 class AriaImagePixelInputs(TensorSchema):
@@ -52,6 +52,8 @@ class AriaImagePixelInputs(TensorSchema):
         - w: Width of each image
     """
 
+    type: Literal["pixel_values"]
+
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
@@ -485,6 +487,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     This model combines a vision tower, a multi-modal projector, and a language
     model to perform tasks that involve both image and text inputs.
     """
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -551,12 +555,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
 
         return AriaImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            pixel_mask=flatten_bn(pixel_mask, concat=True),
+            type="pixel_values",
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
         )
 
     def _create_patch_attention_mask(
-            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        self,
+        pixel_mask: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         if pixel_mask is None:
             return None
 
vllm/model_executor/models/aya_vision.py

@@ -31,7 +31,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 
@@ -295,6 +295,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
                                         dummy_inputs=AyaVisionDummyInputsBuilder)
 class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -379,8 +380,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return AyaVisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,
vllm/model_executor/models/blip2.py

@@ -26,12 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .blip import BlipVisionModel
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
-
-# We use this internally as placeholders since there is no image token
-# defined on the HuggingFace repo
-_IMAGE_TOKEN_ID = 50265
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 
 
 class Blip2ImagePixelInputs(TensorSchema):
@@ -514,6 +509,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
                                         dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsQuant):
+    merge_by_field_config = True
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -570,8 +566,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if pixel_values is not None:
             expected_h = expected_w = self.config.vision_config.image_size
             return Blip2ImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w
@@ -580,7 +575,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if image_embeds is not None:
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
vllm/model_executor/models/chameleon.py

@@ -42,7 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
-from .utils import (flatten_bn, is_pp_missing_parameter,
+from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -935,6 +935,8 @@ class ChameleonModel(nn.Module):
                                         dummy_inputs=ChameleonDummyInputsBuilder)
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP, SupportsQuant):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
@@ -981,8 +983,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         expected_h = expected_w = vq_config.resolution
 
         return ChameleonImagePixelInputs(type="pixel_values",
-                                         data=flatten_bn(pixel_values,
-                                                         concat=True),
+                                         data=pixel_values,
                                          resolve_bindings={
                                              "h": expected_h,
                                              "w": expected_w
vllm/model_executor/models/cohere2_vision.py

@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 
@@ -317,6 +317,7 @@ class Cohere2VisionMultiModalProcessor(
                                         dummy_inputs=Cohere2VisionDummyInputsBuilder)
 class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -399,8 +400,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return Cohere2VisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_patches, concat=True),
+            pixel_values=pixel_values,
+            num_patches=num_patches,
             resolve_bindings={
                 "h": self.config.vision_config.image_size,
                 "w": self.config.vision_config.image_size,