Remove MPTConfig (#1529)

Author: Woosuk Kwon, 2023-11-01 15:29:05 -07:00 (committed by GitHub)
Parent: 7e90a2d117
Commit: 1fe0990023
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
6 changed files with 26 additions and 102 deletions
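
This change removes vLLM's vendored MPTConfig and its custom-cased class names in favor of the MptConfig shipped with transformers: MPTForCausalLM, MPTModel, MPTBlock, MPTMLP, and MPTAttention become MptForCausalLM, MptModel, MptBlock, MptMLP, and MptAttention. Because transformers exposes the attention settings as an MptAttentionConfig object rather than a plain dict, the attention fields are now read as attributes. A minimal sketch of the new access pattern (assuming a transformers release recent enough to ship MptConfig, roughly 4.32 or later):

# Illustrative sketch only, not part of the commit.
from transformers import MptConfig

config = MptConfig()                      # library defaults
print(config.attn_config.clip_qkv)        # attribute access; previously config.attn_config["clip_qkv"]
print(config.attn_config.alibi_bias_max)  # previously config.attn_config["alibi_bias_max"]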

View File

@@ -28,8 +28,8 @@ _MODEL_REGISTRY = {
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MistralForCausalLM": MistralForCausalLM,
    # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
+    "MptForCausalLM": MptForCausalLM,
+    "MPTForCausalLM": MptForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,

View File

@@ -10,7 +10,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
+from vllm.model_executor.models.mpt import MptForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
@@ -26,7 +26,7 @@ __all__ = [
    "GPTNeoXForCausalLM",
    "InternLMForCausalLM",
    "LlamaForCausalLM",
-    "MPTForCausalLM",
+    "MptForCausalLM",
    "OPTForCausalLM",
    "QWenLMHeadModel",
    "MistralForCausalLM",

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
+from transformers import MptConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
@@ -19,7 +20,6 @@ from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.mpt import MPTConfig


KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -37,17 +37,17 @@ def _get_alibi_slopes(
    return slopes


-class MPTAttention(nn.Module):
+class MptAttention(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
@@ -105,9 +105,9 @@ class MPTAttention(nn.Module):
        return output


-class MPTMLP(nn.Module):
+class MptMLP(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
@@ -133,15 +133,15 @@ class MPTMLP(nn.Module):
        return x


-class MPTBlock(nn.Module):
+class MptBlock(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config)
+        self.attn = MptAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config)
+        self.ffn = MptMLP(config)

    def forward(
        self,
@@ -166,9 +166,9 @@ class MPTBlock(nn.Module):
        return hidden_states


-class MPTModel(nn.Module):
+class MptModel(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +178,7 @@ class MPTModel(nn.Module):
            config.d_model,
        )
        self.blocks = nn.ModuleList(
-            [MPTBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
@@ -213,14 +213,14 @@ class MPTModel(nn.Module):
        return hidden_states


-class MPTForCausalLM(nn.Module):
+class MptForCausalLM(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

-        self.transformer = MPTModel(config)
+        self.transformer = MptModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight

View File

@@ -1,11 +1,11 @@
from typing import Optional

-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, MptConfig, PretrainedConfig

from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
-    "mpt": MPTConfig,
+    "mpt": MptConfig,
    "baichuan": BaiChuanConfig,
    "aquila": AquilaConfig,
    "qwen": QWenConfig,

View File

@@ -1,4 +1,3 @@
-from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
@@ -8,7+7,6 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
from vllm.transformers_utils.configs.falcon import RWConfig

__all__ = [
-    "MPTConfig",
    "BaiChuanConfig",
    "AquilaConfig",
    "QWenConfig",

View File

@@ -1,74 +0,0 @@
-# Adapted from
-# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
-from typing import Any, Dict, Optional, Union
-
-from transformers import PretrainedConfig
-
-_ATTN_CONFIG_DEFAULTS = {
-    "attn_type": "multihead_attention",
-    "attn_pdrop": 0.0,
-    "attn_impl": "triton",
-    "qk_ln": False,
-    "clip_qkv": None,
-    "softmax_scale": None,
-    "prefix_lm": False,
-    "attn_uses_sequence_id": False,
-    "alibi": False,
-    "alibi_bias_max": 8,
-}
-
-
-class MPTConfig(PretrainedConfig):
-    model_type = "mpt"
-    attribute_map = {
-        "hidden_size": "d_model",
-        "num_attention_heads": "n_heads",
-        "num_hidden_layers": "n_layers",
-    }
-
-    def __init__(
-        self,
-        d_model: int = 2048,
-        n_heads: int = 16,
-        n_layers: int = 24,
-        expansion_ratio: int = 4,
-        max_seq_len: int = 2048,
-        vocab_size: int = 50368,
-        resid_pdrop: float = 0.0,
-        emb_pdrop: float = 0.0,
-        learned_pos_emb: bool = True,
-        attn_config: Optional[Dict[str, Any]] = None,
-        init_device: str = "cpu",
-        logit_scale: Optional[Union[float, str]] = None,
-        no_bias: bool = False,
-        verbose: int = 0,
-        embedding_fraction: float = 1.0,
-        norm_type: str = "low_precision_layernorm",
-        use_cache: bool = False,
-        **kwargs,
-    ) -> None:
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.expansion_ratio = expansion_ratio
-        self.max_seq_len = max_seq_len
-        self.vocab_size = vocab_size
-        self.resid_pdrop = resid_pdrop
-        self.emb_pdrop = emb_pdrop
-        self.learned_pos_emb = learned_pos_emb
-        if attn_config is None:
-            self.attn_config = _ATTN_CONFIG_DEFAULTS
-        else:
-            self.attn_config = attn_config
-        self.init_device = init_device
-        self.logit_scale = logit_scale
-        self.no_bias = no_bias
-        self.verbose = verbose
-        self.embedding_fraction = embedding_fraction
-        self.norm_type = norm_type
-        self.use_cache = use_cache
-        if "name" in kwargs:
-            del kwargs["name"]
-        if "loss_fn" in kwargs:
-            del kwargs["loss_fn"]
-        super().__init__(**kwargs)
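
With the vendored file deleted, equivalent hyperparameters are expressed directly through transformers' MptConfig, which uses the same field names and keeps the same attribute_map aliases (hidden_size -> d_model, and so on). A minimal sketch using the deleted file's default values (assumption: these keyword arguments are accepted by transformers' MptConfig):

# Illustrative sketch only, not part of the commit.
from transformers import MptConfig

config = MptConfig(
    d_model=2048,
    n_heads=16,
    n_layers=24,
    expansion_ratio=4,
    max_seq_len=2048,
    vocab_size=50368,
)
print(config.hidden_size)  # 2048, via the attribute_map alias for d_model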