[Model] Support telechat2 (#10311)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: xiangw2 <xiangw2@chinatelecom.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent e2251109c7
commit 1209261e93
@@ -309,6 +309,11 @@ Text Generation
     - :code:`upstage/solar-pro-preview-instruct`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`TeleChat2ForCausalLM`
+    - TeleChat2
+    - :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc.
+    - ✅︎
+    - ✅︎
   * - :code:`XverseForCausalLM`
     - XVERSE
     - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
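Note: for reference, a minimal offline-inference sketch for the newly listed architecture. The model name and sampling settings are illustrative; TeleChat2 checkpoints ship custom HF code, hence trust_remote_code=True, matching the test registry entry below.

from vllm import LLM, SamplingParams

# Hedged example: assumes the TeleAI/TeleChat2-3B checkpoint is reachable and
# that executing its custom Hugging Face code is acceptable.
llm = LLM(model="TeleAI/TeleChat2-3B", trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)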
@@ -115,6 +115,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
+    "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
+                                            trust_remote_code=True),
     "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                          is_available_online=False,
                                          trust_remote_code=True),
@@ -501,8 +501,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.config = config
         self.lora_config = lora_config

-        self.model = LlamaModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
+        self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
             if lora_config:
@@ -539,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             normalize=False,
             softmax=False)

+    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+        return LlamaModel(vllm_config=vllm_config, prefix=prefix)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

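Note: the new _init_model hook lets Llama-derived architectures substitute their own backbone without copying LlamaForCausalLM.__init__. A minimal sketch of the pattern; the class names here are hypothetical (TeleChat2, added below, is the real user of the hook):

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel


class MyVariantModel(LlamaModel):
    """Hypothetical backbone that tweaks the HF config before building layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        vllm_config.model_config.hf_config.mlp_bias = True  # illustrative tweak
        super().__init__(vllm_config=vllm_config, prefix=prefix)


class MyVariantForCausalLM(LlamaForCausalLM):

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        # Only the backbone changes; lm_head, logits processing, etc. stay in the base class.
        return MyVariantModel(vllm_config=vllm_config, prefix=prefix)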
@@ -91,6 +91,7 @@ _TEXT_GENERATION_MODELS = {
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
+    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
@@ -118,6 +119,7 @@ _EMBEDDING_MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
+    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
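Note: each registry value is a (module, class) pair that is resolved lazily. Roughly, the lookup amounts to the following simplified sketch (not the registry's actual code path):

import importlib

# Simplified sketch: the registry stores ("telechat2", "TeleChat2ForCausalLM"),
# i.e. the module under vllm.model_executor.models and the class inside it.
module_name, class_name = ("telechat2", "TeleChat2ForCausalLM")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)
print(model_cls.__name__)  # -> TeleChat2ForCausalLM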
vllm/model_executor/models/telechat2.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple

import torch

from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                    is_pp_missing_parameter)


class TeleChat2Model(LlamaModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # 1. Initialize the LlamaModel with bias
        vllm_config.model_config.hf_config.bias = True
        vllm_config.model_config.hf_config.mlp_bias = True
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # 2. Remove the bias from the qkv_proj and gate_up_proj based on config
        # Telechat2's gate_up_proj and qkv_proj don't have bias
        # see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
        for layer in self.layers:
            if not isinstance(layer, PPMissingLayer):
                layer.self_attn.qkv_proj.bias = None
                layer.self_attn.qkv_proj.skip_bias_add = True
                layer.mlp.gate_up_proj.bias = None
                layer.mlp.gate_up_proj.skip_bias_add = True

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            ('gate_up_proj', 'gate_proj', 0),
            ('gate_up_proj', 'up_proj', 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        total_num_heads = self.config.n_head
        head_dim = self.config.hidden_size // total_num_heads
        for name, loaded_weight in weights:
            if "self_attn.key_value" in name:
                k_weight = []
                v_weight = []
                for i in range(total_num_heads):
                    start = i * head_dim * 2
                    k_weight.append(loaded_weight[start:start + head_dim, :])
                    v_weight.append(loaded_weight[start + head_dim:start +
                                                  2 * head_dim:])
                k_weight = torch.cat(k_weight, dim=0)
                v_weight = torch.cat(v_weight, dim=0)
                name = name.replace("key_value", "qkv_proj")
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, k_weight, "k")
                weight_loader(param, v_weight, "v")
            elif "query" in name:
                name = name.replace("query", "qkv_proj")
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, "q")
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class TeleChat2ForCausalLM(LlamaForCausalLM):

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:

        hf_to_vllm_mapper = WeightsMapper(
            orig_to_new_prefix={
                "transformer.": "model.",
            },
            orig_to_new_substr={
                ".h.": ".layers.",
                ".self_attention.": ".self_attn.",
                ".word_embeddings.": ".embed_tokens.",
                ".dense.": ".o_proj.",
                ".ln_f.": ".norm.",
            },
        )
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
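Note: the least obvious part of TeleChat2Model.load_weights is the fused "key_value" tensor. K and V for each attention head are stored as adjacent head_dim-sized row blocks, so the loop de-interleaves them before handing the shards to qkv_proj's weight loader. A small self-contained sketch with made-up sizes:

import torch

# Hypothetical sizes, chosen only to illustrate the split performed above.
total_num_heads, head_dim, hidden = 4, 8, 32
key_value = torch.randn(total_num_heads * head_dim * 2, hidden)

k_parts, v_parts = [], []
for i in range(total_num_heads):
    start = i * head_dim * 2
    k_parts.append(key_value[start:start + head_dim, :])                 # K block of head i
    v_parts.append(key_value[start + head_dim:start + 2 * head_dim, :])  # V block of head i

k_weight = torch.cat(k_parts, dim=0)  # [n_head * head_dim, hidden]
v_weight = torch.cat(v_parts, dim=0)  # [n_head * head_dim, hidden]
assert k_weight.shape == v_weight.shape == (total_num_heads * head_dim, hidden)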
@@ -29,7 +29,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
                                              Olmo2Config, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             SolarConfig, Telechat2Config,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import resolve_obj_by_qualname
@@ -64,6 +65,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "NVLM_D": NVLM_D_Config,
     "olmo2": Olmo2Config,
     "solar": SolarConfig,
+    "telechat": Telechat2Config,
     "ultravox": UltravoxConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
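Note: the registry key is the model_type string from a checkpoint's config.json, so registering "telechat" lets vLLM build the config with its own class instead of the checkpoint's remote configuration code. Roughly, as a simplified sketch (not the actual get_config implementation):

from vllm.transformers_utils.configs import Telechat2Config

# Simplified stand-in for the registry lookup: a TeleChat2 checkpoint's
# config.json declares `"model_type": "telechat"`, which now resolves locally.
_CONFIG_REGISTRY = {"telechat": Telechat2Config}

model_type = "telechat"            # as read from the checkpoint's config.json
config_cls = _CONFIG_REGISTRY[model_type]
config = config_cls()              # defaults come from configs/telechat2.py below
print(type(config).__name__)       # -> Telechat2Config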
@@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
 from vllm.transformers_utils.configs.olmo2 import Olmo2Config
 from vllm.transformers_utils.configs.solar import SolarConfig
+from vllm.transformers_utils.configs.telechat2 import Telechat2Config
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

 __all__ = [
@@ -36,5 +37,6 @@ __all__ = [
     "NVLM_D_Config",
     "Olmo2Config",
     "SolarConfig",
+    "Telechat2Config",
     "UltravoxConfig",
 ]
vllm/transformers_utils/configs/telechat2.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py
""" Telechat configuration compatible with LlamaConfig. """

from transformers.configuration_utils import PretrainedConfig


class Telechat2Config(PretrainedConfig):

    model_type = "telechat"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "intermediate_size": "ffn_hidden_size",
        "rms_norm_eps": "layer_norm_epsilon"
    }

    def __init__(
        self,
        vocab_size=160256,
        hidden_size=4096,
        n_layer=30,
        n_head=32,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        ffn_hidden_size=12288,
        training_seqlen=8192,
        logn=True,
        embed_layernorm=False,
        hidden_act="silu",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.logn = logn
        self.training_seqlen = training_seqlen
        self.embed_layernorm = embed_layernorm
        self.num_key_value_heads = kwargs.pop("num_key_value_heads", None)
        self.ffn_hidden_size = ffn_hidden_size
        self.hidden_act = hidden_act
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)
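Note: the attribute_map is what lets the Llama-oriented code paths read standard Hugging Face attribute names even though TeleChat2 checkpoints use their own. A quick hedged example; the argument values are arbitrary:

from vllm.transformers_utils.configs import Telechat2Config

# Arbitrary small sizes, just to show the attribute_map indirection.
cfg = Telechat2Config(n_layer=2, n_head=4, hidden_size=64, ffn_hidden_size=128)

# PretrainedConfig routes these reads through attribute_map:
print(cfg.num_hidden_layers)    # -> 2     (alias of n_layer)
print(cfg.num_attention_heads)  # -> 4     (alias of n_head)
print(cfg.intermediate_size)    # -> 128   (alias of ffn_hidden_size)
print(cfg.rms_norm_eps)         # -> 1e-05 (alias of layer_norm_epsilon)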