diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 9dcfd968b45ca..443f24790674b 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -28,8 +28,8 @@ _MODEL_REGISTRY = {
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MistralForCausalLM": MistralForCausalLM,
     # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
+    "MptForCausalLM": MptForCausalLM,
+    "MPTForCausalLM": MptForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
     "QWenLMHeadModel": QWenLMHeadModel,
     "RWForCausalLM": FalconForCausalLM,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 35d72c16307b4..c4bba4855ef33 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -10,7 +10,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
+from vllm.model_executor.models.mpt import MptForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel

@@ -26,7 +26,7 @@ __all__ = [
     "GPTNeoXForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
-    "MPTForCausalLM",
+    "MptForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",
     "MistralForCausalLM",
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index ba7441e145b16..4a66c5b5dec6c 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn
+from transformers import MptConfig

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -19,7 +20,6 @@ from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.mpt import MPTConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -37,17 +37,17 @@ def _get_alibi_slopes(
     return slopes


-class MPTAttention(nn.Module):
+class MptAttention(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi

         self.qkv_proj = ColumnParallelLinear(
             self.d_model,
@@ -105,9 +105,9 @@ class MPTAttention(nn.Module):
         return output


-class MPTMLP(nn.Module):
+class MptMLP(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
@@ -133,15 +133,15 @@ class MPTMLP(nn.Module):
         return x


-class MPTBlock(nn.Module):
+class MptBlock(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config)
+        self.attn = MptAttention(config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config)
+        self.ffn = MptMLP(config)

     def forward(
         self,
@@ -166,9 +166,9 @@ class MPTBlock(nn.Module):
         return hidden_states


-class MPTModel(nn.Module):
+class MptModel(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +178,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -213,14 +213,14 @@ class MPTModel(nn.Module):
         return hidden_states


-class MPTForCausalLM(nn.Module):
+class MptForCausalLM(nn.Module):

-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings

-        self.transformer = MPTModel(config)
+        self.transformer = MptModel(config)
         # TODO(zhuohan): create a new weight after implementing pipeline
         # parallelism
         self.lm_head_weight = self.transformer.wte.weight
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fd5618bd81ba1..b69e0a1a43850 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,11 +1,11 @@
 from typing import Optional

-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, MptConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

 _CONFIG_REGISTRY = {
-    "mpt": MPTConfig,
+    "mpt": MptConfig,
     "baichuan": BaiChuanConfig,
     "aquila": AquilaConfig,
     "qwen": QWenConfig,
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 6611697d25ae3..f5acb4e07972e 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -1,4 +1,3 @@
-from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.aquila import AquilaConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
@@ -8,7 +7,6 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 from vllm.transformers_utils.configs.falcon import RWConfig

 __all__ = [
-    "MPTConfig",
     "BaiChuanConfig",
     "AquilaConfig",
     "QWenConfig",
diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py
deleted file mode 100644
index 3909f710d44de..0000000000000
--- a/vllm/transformers_utils/configs/mpt.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Adapted from
-# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
-from typing import Any, Dict, Optional, Union
-
-from transformers import PretrainedConfig
-
-_ATTN_CONFIG_DEFAULTS = {
-    "attn_type": "multihead_attention",
-    "attn_pdrop": 0.0,
-    "attn_impl": "triton",
-    "qk_ln": False,
-    "clip_qkv": None,
-    "softmax_scale": None,
-    "prefix_lm": False,
-    "attn_uses_sequence_id": False,
-    "alibi": False,
-    "alibi_bias_max": 8,
-}
-
-
-class MPTConfig(PretrainedConfig):
-    model_type = "mpt"
-    attribute_map = {
-        "hidden_size": "d_model",
-        "num_attention_heads": "n_heads",
-        "num_hidden_layers": "n_layers",
-    }
-
-    def __init__(
-        self,
-        d_model: int = 2048,
-        n_heads: int = 16,
-        n_layers: int = 24,
-        expansion_ratio: int = 4,
-        max_seq_len: int = 2048,
-        vocab_size: int = 50368,
-        resid_pdrop: float = 0.0,
-        emb_pdrop: float = 0.0,
-        learned_pos_emb: bool = True,
-        attn_config: Optional[Dict[str, Any]] = None,
-        init_device: str = "cpu",
-        logit_scale: Optional[Union[float, str]] = None,
-        no_bias: bool = False,
-        verbose: int = 0,
-        embedding_fraction: float = 1.0,
-        norm_type: str = "low_precision_layernorm",
-        use_cache: bool = False,
-        **kwargs,
-    ) -> None:
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.expansion_ratio = expansion_ratio
-        self.max_seq_len = max_seq_len
-        self.vocab_size = vocab_size
-        self.resid_pdrop = resid_pdrop
-        self.emb_pdrop = emb_pdrop
-        self.learned_pos_emb = learned_pos_emb
-        if attn_config is None:
-            self.attn_config = _ATTN_CONFIG_DEFAULTS
-        else:
-            self.attn_config = attn_config
-        self.init_device = init_device
-        self.logit_scale = logit_scale
-        self.no_bias = no_bias
-        self.verbose = verbose
-        self.embedding_fraction = embedding_fraction
-        self.norm_type = norm_type
-        self.use_cache = use_cache
-        if "name" in kwargs:
-            del kwargs["name"]
-        if "loss_fn" in kwargs:
-            del kwargs["loss_fn"]
-        super().__init__(**kwargs)