diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index f572edb41db8..184d5a574e6a 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -141,17 +141,17 @@ class QWenBlock(nn.Module):
 
     def __init__(self, config: QWenConfig):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attn = QWenAttention(config.n_embd,
+        self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta)
 
-        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
 
     def forward(
         self,
@@ -190,11 +190,11 @@ class QWenModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
 
         self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.n_embd,
+                                          config.hidden_size,
                                           perform_initialization=False)
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
     def forward(
         self,
@@ -230,7 +230,7 @@ class QWenLMHeadModel(nn.Module):
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             vocab_size,
             bias=False,
             gather_output=False,
diff --git a/vllm/transformers_utils/configs/qwen.py b/vllm/transformers_utils/configs/qwen.py
index 916bb4c77bc0..bb033a337ad0 100644
--- a/vllm/transformers_utils/configs/qwen.py
+++ b/vllm/transformers_utils/configs/qwen.py
@@ -7,65 +7,54 @@ from transformers import PretrainedConfig
 class QWenConfig(PretrainedConfig):
     model_type = "qwen"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {
-        "hidden_size": "n_embd",
-        "num_attention_heads": "n_head",
-        "max_position_embeddings": "n_positions",
-        "num_hidden_layers": "n_layer",
-    }
 
     def __init__(
         self,
-        vocab_size=151851,
-        n_embd=4096,
-        n_layer=32,
-        n_head=32,
-        n_inner=None,
-        embd_pdrop=0.0,
-        attn_pdrop=0.0,
-        layer_norm_epsilon=1e-5,
+        vocab_size=151936,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        emb_dropout_prob=0.0,
+        attn_dropout_prob=0.0,
+        layer_norm_epsilon=1e-6,
         initializer_range=0.02,
+        max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
-        eos_token_id=151643,
-        apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
-        use_dynamic_ntk=False,
-        use_logn_attn=False,
-        use_flash_attn=True,
-        ffn_hidden_size=22016,
+        use_dynamic_ntk=True,
+        use_logn_attn=True,
+        use_flash_attn="auto",
+        intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
         **kwargs,
     ):
-        self.eos_token_id = eos_token_id
-        super().__init__(eos_token_id=eos_token_id,
-                         tie_word_embeddings=tie_word_embeddings,
-                         **kwargs)
-
         self.vocab_size = vocab_size
-        self.n_embd = n_embd
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.n_inner = n_inner
-        self.embd_pdrop = embd_pdrop
-        self.attn_pdrop = attn_pdrop
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.emb_dropout_prob = emb_dropout_prob
+        self.attn_dropout_prob = attn_dropout_prob
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = (
-            apply_residual_connection_post_layernorm)
+        self.max_position_embeddings = max_position_embeddings
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
-        self.ffn_hidden_size = ffn_hidden_size
         self.no_bias = no_bias
-        self.tie_word_embeddings = tie_word_embeddings
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
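
For context, a minimal sketch of what the rename buys (assuming this patched vLLM checkout, with values taken only from the new defaults in the diff): QWenConfig now exposes the HF-style attribute names directly instead of mapping them through attribute_map, so the model code can read hidden_size, intermediate_size, and num_hidden_layers straight off the config.

# Minimal sketch, assuming the patched QWenConfig above is importable from
# this vLLM checkout; all expected values come from the new __init__ defaults.
from vllm.transformers_utils.configs.qwen import QWenConfig

cfg = QWenConfig()

# Renamed fields are now plain instance attributes (no attribute_map needed):
assert cfg.hidden_size == 4096             # was cfg.n_embd
assert cfg.intermediate_size == 22016      # was cfg.ffn_hidden_size
assert cfg.num_hidden_layers == 32         # was cfg.n_layer
assert cfg.num_attention_heads == 32       # was cfg.n_head
assert cfg.max_position_embeddings == 8192 # new explicit default

# Per-projection width passed to QWenMLP in the model code above:
print(cfg.intermediate_size // 2)          # 11008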