[Minor] Remove unused config files (#3039)

This commit is contained in:
Roy 2024-02-27 09:25:22 +08:00 committed by GitHub
parent d6e4a130b0
commit d9f726c4d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 10 additions and 210 deletions

View File

@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
@ -42,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -186,7 +186,7 @@ class BaiChuanAttention(nn.Module):
class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: BaiChuanConfig,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
@ -245,7 +245,7 @@ class BaiChuanDecoderLayer(nn.Module):
class BaiChuanModel(nn.Module):
def __init__(self,
config: BaiChuanConfig,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()

View File

@ -61,7 +61,9 @@ from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig
# this model must need this dependency
from hf_olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
@ -27,7 +28,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -127,7 +127,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
@ -179,7 +179,7 @@ class QWenModel(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
@ -222,7 +222,7 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()

View File

@ -5,10 +5,8 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = {
"baichuan": BaiChuanConfig,
"chatglm": ChatGLMConfig,
"mpt": MPTConfig,
"qwen": QWenConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
}

View File

@ -1,18 +1,12 @@
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.olmo import OLMoConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
__all__ = [
"BaiChuanConfig",
"ChatGLMConfig",
"MPTConfig",
"OLMoConfig",
"QWenConfig",
"RWConfig",
]

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
class BaiChuanConfig(PretrainedConfig):
model_type = "baichuan"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -1,72 +0,0 @@
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig
class OLMoConfig(PretrainedConfig):
model_type = 'olmo'
attribute_map = {
'num_attention_heads': 'n_heads',
'hidden_size': 'd_model',
'num_hidden_layers': 'n_layers',
}
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
def __init__(
self,
d_model=768,
n_heads=12,
n_layers=12,
mlp_ratio=4,
mlp_hidden_size=None,
activation_type="swiglu",
block_type="sequential",
block_group_size=1,
alibi=False,
alibi_bias_max=8.0,
rope=False,
rope_full_precision=True,
multi_query_attention=False,
attention_layer_norm=False,
layer_norm_type="default",
layer_norm_with_affine=True,
attention_layer_norm_with_affine=True,
max_sequence_length=1024,
include_bias=True,
bias_for_layer_norm=None,
scale_logits=False,
vocab_size=50257,
embedding_size=50304,
weight_tying=True,
eos_token_id=50256,
pad_token_id=50256,
**kwargs,
):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.mlp_ratio = mlp_ratio
self.mlp_hidden_size = mlp_hidden_size
self.activation_type = activation_type
self.block_type = block_type
self.block_group_size = block_group_size
self.alibi = alibi
self.alibi_bias_max = alibi_bias_max
self.rope = rope
self.rope_full_precision = rope_full_precision
self.multi_query_attention = multi_query_attention
self.attention_layer_norm = attention_layer_norm
self.layer_norm_type = layer_norm_type
self.layer_norm_with_affine = layer_norm_with_affine
self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
self.max_sequence_length = max_sequence_length
self.include_bias = include_bias
self.bias_for_layer_norm = bias_for_layer_norm
self.scale_logits = scale_logits
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.weight_tying = weight_tying
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
super().__init__(**kwargs)

View File

@ -1,60 +0,0 @@
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)