mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 05:25:01 +08:00
[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)
This commit is contained in:
parent
4e94951bb1
commit
630eb5b5ce
@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from packaging.version import Version
|
||||||
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||||
PixtralVisionConfig, PretrainedConfig,
|
PixtralVisionConfig, PretrainedConfig,
|
||||||
SiglipVisionConfig)
|
SiglipVisionConfig)
|
||||||
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
from transformers.models.llava import LlavaProcessor
|
from transformers.models.llava import LlavaProcessor
|
||||||
from transformers.models.pixtral import PixtralProcessor
|
from transformers.models.pixtral import PixtralProcessor
|
||||||
|
|
||||||
@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class MantisProcessingInfo(LlavaProcessingInfo):
|
||||||
|
|
||||||
|
def get_hf_processor(self):
|
||||||
|
hf_config = self.get_hf_config()
|
||||||
|
vision_info = self.get_vision_encoder_info()
|
||||||
|
|
||||||
|
if Version(TRANSFORMERS_VERSION) < Version("4.48"):
|
||||||
|
# BUG: num_additional_image_tokens = 0 but treated as 1,
|
||||||
|
# so we set vision_feature_select_strategy to None to offset this
|
||||||
|
vision_feature_select_strategy = None
|
||||||
|
else:
|
||||||
|
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
|
||||||
|
vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501
|
||||||
|
|
||||||
|
return self.ctx.get_hf_processor(
|
||||||
|
LlavaProcessor,
|
||||||
|
patch_size=vision_info.get_patch_size(),
|
||||||
|
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
|||||||
# To use this model, please use
|
# To use this model, please use
|
||||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
|
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
|
||||||
info=LlavaProcessingInfo,
|
info=MantisProcessingInfo,
|
||||||
dummy_inputs=LlavaDummyInputsBuilder)
|
dummy_inputs=LlavaDummyInputsBuilder)
|
||||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -36,8 +36,9 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||||
NestedTensors)
|
MultiModalInputsV2, MultiModalKwargs,
|
||||||
|
NestedTensors, PlaceholderRange)
|
||||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||||
MultiModalDataParser)
|
MultiModalDataParser)
|
||||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
@ -153,29 +154,24 @@ class Qwen2AudioMultiModalProcessor(
|
|||||||
mm_data: Mapping[str, object],
|
mm_data: Mapping[str, object],
|
||||||
mm_kwargs: Mapping[str, Any],
|
mm_kwargs: Mapping[str, Any],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
mm_data = dict(mm_data)
|
# Text-only input not supported in composite processor
|
||||||
audios = mm_data.pop("audios", [])
|
if not mm_data or not mm_data.get("audios", []):
|
||||||
|
prompt_ids = self.info.get_tokenizer().encode(prompt)
|
||||||
if audios:
|
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
|
||||||
mm_data["audios"] = audios
|
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||||
|
|
||||||
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||||
mm_kwargs = dict(
|
mm_kwargs = dict(
|
||||||
**mm_kwargs,
|
**mm_kwargs,
|
||||||
sampling_rate=feature_extractor.sampling_rate,
|
sampling_rate=feature_extractor.sampling_rate,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
|
||||||
pass
|
|
||||||
|
|
||||||
processed_outputs = super()._call_hf_processor(
|
return super()._call_hf_processor(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
mm_data=mm_data,
|
mm_data=mm_data,
|
||||||
mm_kwargs=mm_kwargs,
|
mm_kwargs=mm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return processed_outputs
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
self,
|
self,
|
||||||
hf_inputs: BatchFeature,
|
hf_inputs: BatchFeature,
|
||||||
@ -192,8 +188,14 @@ class Qwen2AudioMultiModalProcessor(
|
|||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
out_mm_kwargs: MultiModalKwargs,
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
) -> list[PromptReplacement]:
|
) -> list[PromptReplacement]:
|
||||||
hf_config = self.info.get_hf_config()
|
processor = self.info.get_hf_processor()
|
||||||
placeholder = hf_config.audio_token_index
|
|
||||||
|
# Use getattr with default to be compatible with transformers<4.48
|
||||||
|
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||||
|
audio_bos_token = getattr(processor, "audio_bos_token",
|
||||||
|
"<|audio_bos|>")
|
||||||
|
audio_eos_token = getattr(processor, "audio_eos_token",
|
||||||
|
"<|audio_eos|>")
|
||||||
|
|
||||||
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||||
if feature_attention_mask is None:
|
if feature_attention_mask is None:
|
||||||
@ -214,12 +216,16 @@ class Qwen2AudioMultiModalProcessor(
|
|||||||
f"The audio {audio} (len={len(audio)}) is too short "
|
f"The audio {audio} (len={len(audio)}) is too short "
|
||||||
"to be represented inside the model")
|
"to be represented inside the model")
|
||||||
|
|
||||||
return [placeholder] * num_placeholders
|
return "".join([
|
||||||
|
audio_bos_token,
|
||||||
|
audio_token * num_placeholders,
|
||||||
|
audio_eos_token,
|
||||||
|
])
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PromptReplacement(
|
PromptReplacement(
|
||||||
modality="audio",
|
modality="audio",
|
||||||
target=[placeholder],
|
target=audio_token,
|
||||||
replacement=get_replacement_qwen2_audio,
|
replacement=get_replacement_qwen2_audio,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -234,6 +240,26 @@ class Qwen2AudioMultiModalProcessor(
|
|||||||
# tokens than the number of audio items)
|
# tokens than the number of audio items)
|
||||||
return not hasattr(self.info.get_hf_processor(), "audio_token")
|
return not hasattr(self.info.get_hf_processor(), "audio_token")
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, list[int]],
|
||||||
|
mm_data: MultiModalDataDict,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> MultiModalInputsV2:
|
||||||
|
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||||
|
|
||||||
|
# Only <|AUDIO|> tokens should be considered as placeholders,
|
||||||
|
# so we ignore the audio_bos_token and audio_eos_token
|
||||||
|
result["mm_placeholders"] = {
|
||||||
|
modality: [
|
||||||
|
PlaceholderRange(offset=p["offset"] + 1,
|
||||||
|
length=p["length"] - 2) for p in ps
|
||||||
|
]
|
||||||
|
for modality, ps in result["mm_placeholders"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
Qwen2AudioMultiModalProcessor,
|
Qwen2AudioMultiModalProcessor,
|
||||||
|
|||||||
@ -137,7 +137,7 @@ class UltravoxMultiModalProcessor(
|
|||||||
mm_kwargs: Mapping[str, object],
|
mm_kwargs: Mapping[str, object],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
# Text-only input not supported in composite processor
|
# Text-only input not supported in composite processor
|
||||||
if not mm_data:
|
if not mm_data or not mm_data.get("audios", []):
|
||||||
prompt_ids = self.info.get_tokenizer().encode(prompt)
|
prompt_ids = self.info.get_tokenizer().encode(prompt)
|
||||||
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
|
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
|
||||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||||
@ -146,13 +146,6 @@ class UltravoxMultiModalProcessor(
|
|||||||
audios = mm_data.pop("audios", [])
|
audios = mm_data.pop("audios", [])
|
||||||
assert isinstance(audios, list)
|
assert isinstance(audios, list)
|
||||||
|
|
||||||
if not audios:
|
|
||||||
return super()._call_hf_processor(
|
|
||||||
prompt=prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
mm_kwargs=mm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_extractor = self.info.get_feature_extractor()
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
mm_kwargs = dict(
|
mm_kwargs = dict(
|
||||||
**mm_kwargs,
|
**mm_kwargs,
|
||||||
|
|||||||
@ -22,10 +22,10 @@ from vllm.envs import VLLM_USE_MODELSCOPE
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
|
from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
|
||||||
DbrxConfig, DeepseekVLV2Config,
|
Cohere2Config, DbrxConfig,
|
||||||
EAGLEConfig, ExaoneConfig,
|
DeepseekVLV2Config, EAGLEConfig,
|
||||||
H2OVLChatConfig,
|
ExaoneConfig, H2OVLChatConfig,
|
||||||
InternVLChatConfig, JAISConfig,
|
InternVLChatConfig, JAISConfig,
|
||||||
MedusaConfig, MllamaConfig,
|
MedusaConfig, MllamaConfig,
|
||||||
MLPSpeculatorConfig, MPTConfig,
|
MLPSpeculatorConfig, MPTConfig,
|
||||||
@ -52,6 +52,7 @@ _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||||
|
"aria": AriaConfig,
|
||||||
"chatglm": ChatGLMConfig,
|
"chatglm": ChatGLMConfig,
|
||||||
"cohere2": Cohere2Config,
|
"cohere2": Cohere2Config,
|
||||||
"dbrx": DbrxConfig,
|
"dbrx": DbrxConfig,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from vllm.transformers_utils.configs.aria import AriaConfig
|
||||||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||||
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
|
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
|
||||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||||
@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
|
|||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AriaConfig",
|
||||||
"ChatGLMConfig",
|
"ChatGLMConfig",
|
||||||
"Cohere2Config",
|
"Cohere2Config",
|
||||||
"DbrxConfig",
|
"DbrxConfig",
|
||||||
|
|||||||
@ -1,7 +1,32 @@
|
|||||||
|
# Copyright 2024 Rhymes AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
from transformers.models.idefics2.configuration_idefics2 import (
|
from transformers.models.idefics2.configuration_idefics2 import (
|
||||||
Idefics2VisionConfig)
|
Idefics2VisionConfig)
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AriaVisionConfig(Idefics2VisionConfig):
|
class AriaVisionConfig(Idefics2VisionConfig):
|
||||||
model_type = "aria_vision_model"
|
model_type = "aria_vision_model"
|
||||||
@ -45,3 +70,96 @@ class AriaMoELMConfig(LlamaConfig):
|
|||||||
self.moe_num_experts = moe_num_experts
|
self.moe_num_experts = moe_num_experts
|
||||||
self.moe_topk = moe_topk
|
self.moe_topk = moe_topk
|
||||||
self.moe_num_shared_experts = moe_num_shared_experts
|
self.moe_num_shared_experts = moe_num_shared_experts
|
||||||
|
|
||||||
|
|
||||||
|
class AriaConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
Configuration class for Aria model.
|
||||||
|
This class handles the configuration for both vision and text components of
|
||||||
|
the Aria model,
|
||||||
|
as well as additional parameters for image token handling and projector
|
||||||
|
mapping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_config (AriaVisionConfig or dict): Configuration for the vision
|
||||||
|
component.
|
||||||
|
text_config (AriaMoELMConfig or dict): Configuration for the text
|
||||||
|
component.
|
||||||
|
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
|
||||||
|
dimensions.
|
||||||
|
ignore_index (int): Index to ignore in loss calculation.
|
||||||
|
image_token_index (int): Index used to represent image tokens.
|
||||||
|
**kwargs: Additional keyword arguments passed to the parent class.
|
||||||
|
Attributes:
|
||||||
|
model_type (str): Type of the model, set to "aria".
|
||||||
|
is_composition (bool): Whether the model is a composition of multiple
|
||||||
|
components.
|
||||||
|
ignore_index (int): Index to ignore in loss calculation.
|
||||||
|
image_token_index (int): Index used to represent image tokens.
|
||||||
|
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
|
||||||
|
dimensions.
|
||||||
|
vision_config (AriaVisionConfig): Configuration for the vision
|
||||||
|
component.
|
||||||
|
text_config (AriaMoELMConfig): Configuration for the text component.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "aria"
|
||||||
|
is_composition = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008
|
||||||
|
text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008
|
||||||
|
projector_patch_to_query_dict: Mapping[int, int] = {
|
||||||
|
1225: 128,
|
||||||
|
4900: 256,
|
||||||
|
},
|
||||||
|
ignore_index=-100,
|
||||||
|
image_token_index=32000,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
attn_implementation = kwargs.pop("attn_implementation", None)
|
||||||
|
|
||||||
|
# Set the default attention implementation to flash_attention_2 if not
|
||||||
|
# specified
|
||||||
|
self._attn_implementation = ("flash_attention_2"
|
||||||
|
if attn_implementation is None else
|
||||||
|
attn_implementation)
|
||||||
|
|
||||||
|
# Convert the keys and values of projector_patch_to_query_dict to
|
||||||
|
# integers
|
||||||
|
# This ensures consistency even if they were provided as strings
|
||||||
|
self.projector_patch_to_query_dict = {
|
||||||
|
int(k): int(v)
|
||||||
|
for k, v in projector_patch_to_query_dict.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(vision_config, dict) and "model_type" in vision_config:
|
||||||
|
vision_config = AriaVisionConfig(**vision_config)
|
||||||
|
if attn_implementation is None:
|
||||||
|
vision_attn_implementation = "flash_attention_2"
|
||||||
|
elif attn_implementation == "sdpa":
|
||||||
|
logger.warning("SDPA is not supported for vit, using "
|
||||||
|
"flash_attention_2 instead")
|
||||||
|
vision_attn_implementation = "flash_attention_2"
|
||||||
|
else:
|
||||||
|
vision_attn_implementation = attn_implementation
|
||||||
|
vision_config._attn_implementation = vision_attn_implementation
|
||||||
|
|
||||||
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
if isinstance(text_config, dict) and "model_type" in text_config:
|
||||||
|
text_attn_implementation = ("sdpa" if attn_implementation is None
|
||||||
|
else attn_implementation)
|
||||||
|
text_config = AriaMoELMConfig(**text_config)
|
||||||
|
text_config._attn_implementation = text_attn_implementation
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
|
||||||
|
# This is needed for the static kv cache
|
||||||
|
self.num_hidden_layers = self.text_config.num_hidden_layers
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user