mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-22 10:54:29 +08:00
[Misc] Update config loading for Qwen2-VL and remove Granite (#8837)
This commit is contained in:
parent
7193774b1f
commit
4bb98f2190
@ -280,7 +280,7 @@ Multimodal Language Models
|
||||
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
|
||||
-
|
||||
* - :code:`Qwen2VLForConditionalGeneration`
|
||||
- Qwen2-VL (see note)
|
||||
- Qwen2-VL
|
||||
- Image\ :sup:`+` / Video\ :sup:`+`
|
||||
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
|
||||
-
|
||||
@ -297,15 +297,6 @@ Multimodal Language Models
|
||||
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
|
||||
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
|
||||
|
||||
.. note::
|
||||
For :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
|
||||
This can be installed by running the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
|
||||
|
||||
----
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
|
||||
|
||||
@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GraniteConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
@ -48,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.granite import GraniteConfig
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -31,12 +31,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from transformers import Qwen2VLConfig
|
||||
from transformers.image_utils import (get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
to_numpy_array)
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
Qwen2VLVisionConfig)
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||
make_batched_images, make_batched_videos, smart_resize)
|
||||
|
||||
@ -66,6 +63,8 @@ from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
|
||||
Qwen2VLVisionConfig)
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
|
||||
@ -20,10 +20,10 @@ from vllm.logger import init_logger
|
||||
# yapf: disable
|
||||
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
||||
EAGLEConfig, ExaoneConfig,
|
||||
GraniteConfig, InternVLChatConfig,
|
||||
JAISConfig, MedusaConfig,
|
||||
MllamaConfig, MLPSpeculatorConfig,
|
||||
MPTConfig, NemotronConfig,
|
||||
InternVLChatConfig, JAISConfig,
|
||||
MedusaConfig, MllamaConfig,
|
||||
MLPSpeculatorConfig, MPTConfig,
|
||||
NemotronConfig, Qwen2VLConfig,
|
||||
RWConfig, SolarConfig,
|
||||
UltravoxConfig)
|
||||
# yapf: enable
|
||||
@ -57,9 +57,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
"nemotron": NemotronConfig,
|
||||
"solar": SolarConfig,
|
||||
"ultravox": UltravoxConfig,
|
||||
# Granite can be removed from here once we have upgraded to
|
||||
# transformers 4.45+
|
||||
"granite": GraniteConfig,
|
||||
"qwen2_vl": Qwen2VLConfig,
|
||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||
}
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
|
||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
from vllm.transformers_utils.configs.granite import GraniteConfig
|
||||
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
|
||||
from vllm.transformers_utils.configs.jais import JAISConfig
|
||||
from vllm.transformers_utils.configs.medusa import MedusaConfig
|
||||
@ -14,6 +13,8 @@ from vllm.transformers_utils.configs.mllama import MllamaConfig
|
||||
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
|
||||
Qwen2VLVisionConfig)
|
||||
from vllm.transformers_utils.configs.solar import SolarConfig
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
|
||||
@ -32,7 +33,6 @@ __all__ = [
|
||||
"NemotronConfig",
|
||||
"SolarConfig",
|
||||
"UltravoxConfig",
|
||||
# Granite can be removed from here once we have upgraded to
|
||||
# transformers 4.45+
|
||||
"GraniteConfig",
|
||||
"Qwen2VLConfig",
|
||||
"Qwen2VLVisionConfig",
|
||||
]
|
||||
|
||||
@ -1,199 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Granite model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of
|
||||
a [`GraniteModel`]. It is used to instantiate an Granite
|
||||
model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Granite-3B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to
|
||||
control the model outputs. Read the documentation from [`PretrainedConfig`]
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the Granite model. Defines the number of
|
||||
different tokens that can be represented by the `inputs_ids`
|
||||
passed when calling [`GraniteModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the
|
||||
Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to
|
||||
implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi
|
||||
Head Attention (MHA), if `num_key_value_heads=1` the model will use
|
||||
Multi Query Attention (MQA) otherwise GQA is used. When converting
|
||||
a multi-head checkpoint to a GQA checkpoint, each group key and
|
||||
value head should be constructed by meanpooling all the original
|
||||
heads within that group. For more details checkout
|
||||
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
|
||||
specified, will default to `num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the
|
||||
decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values
|
||||
attentions (not used by all models). Only relevant if
|
||||
`config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE
|
||||
embeddings. Currently supports two scaling strategies: linear and
|
||||
dynamic. Their scaling factor must be a float greater than 1. The
|
||||
expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`.
|
||||
When using this flag, don't update `max_position_embeddings` to
|
||||
the expected new maximum. See the following thread for more
|
||||
information on how these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/.
|
||||
This is an experimental feature, subject to breaking API changes
|
||||
in future versions.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output
|
||||
projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers
|
||||
in the MLP layers.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
embedding multiplier
|
||||
logits_scaling (`float`, *optional*, defaults to 1.0):
|
||||
divisor for output logits
|
||||
residual_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
residual multiplier
|
||||
attention_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
attention multiplier
|
||||
|
||||
```python
|
||||
>>> from transformers import GraniteModel, GraniteConfig
|
||||
|
||||
>>> # Initializing a Granite granite-3b style configuration
|
||||
>>> configuration = GraniteConfig()
|
||||
|
||||
>>> # Initializing a model from the granite-7b style configuration
|
||||
>>> model = GraniteModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "granite"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
embedding_multiplier=1.0,
|
||||
logits_scaling=1.0,
|
||||
residual_multiplier=1.0,
|
||||
attention_multiplier=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
self.logits_scaling = logits_scaling
|
||||
self.residual_multiplier = residual_multiplier
|
||||
self.attention_multiplier = attention_multiplier
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
rope_config_validation(self)
|
||||
131
vllm/transformers_utils/configs/qwen2vl.py
Normal file
131
vllm/transformers_utils/configs/qwen2vl.py
Normal file
@ -0,0 +1,131 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen2VL model configuration"""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class Qwen2VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_vl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
embed_dim=1280,
|
||||
hidden_size=3584,
|
||||
hidden_act="quick_gelu",
|
||||
mlp_ratio=4,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
|
||||
os.PathLike],
|
||||
**kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if config_dict.get("model_type") == "qwen2_vl":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class Qwen2VLConfig(PretrainedConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = Qwen2VLVisionConfig()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# NOTE: the following section from original transformers config
|
||||
# for Qwen2-VL is commented out to address rope config loading issue
|
||||
#
|
||||
# if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
# if self.rope_scaling["type"] == "mrope":
|
||||
# self.rope_scaling["type"] = "default"
|
||||
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
# rope_config_validation(self)
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
Loading…
x
Reference in New Issue
Block a user