fix qwen-14b model (#1173)

This commit is contained in:
Qing 2023-09-28 07:33:16 +08:00 committed by GitHub
parent 30e775281d
commit 28e616c4e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 43 deletions

View File

@@ -141,17 +141,17 @@ class QWenBlock(nn.Module):
def __init__(self, config: QWenConfig): def __init__(self, config: QWenConfig):
super().__init__() super().__init__()
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
self.attn = QWenAttention(config.n_embd, self.attn = QWenAttention(config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
config.max_position_embeddings, config.max_position_embeddings,
rope_theta=rope_theta) rope_theta=rope_theta)
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2) self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
def forward( def forward(
self, self,
@@ -190,11 +190,11 @@ class QWenModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.wte = VocabParallelEmbedding(vocab_size,
config.n_embd, config.hidden_size,
perform_initialization=False) perform_initialization=False)
self.h = nn.ModuleList( self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)]) [QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward( def forward(
self, self,
@@ -230,7 +230,7 @@ class QWenLMHeadModel(nn.Module):
self.transformer = QWenModel(config) self.transformer = QWenModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear( self.lm_head = ColumnParallelLinear(
config.n_embd, config.hidden_size,
vocab_size, vocab_size,
bias=False, bias=False,
gather_output=False, gather_output=False,

View File

@@ -7,65 +7,54 @@ from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig): class QWenConfig(PretrainedConfig):
model_type = "qwen" model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"max_position_embeddings": "n_positions",
"num_hidden_layers": "n_layer",
}
def __init__( def __init__(
self, self,
vocab_size=151851, vocab_size=151936,
n_embd=4096, hidden_size=4096,
n_layer=32, num_hidden_layers=32,
n_head=32, num_attention_heads=32,
n_inner=None, emb_dropout_prob=0.0,
embd_pdrop=0.0, attn_dropout_prob=0.0,
attn_pdrop=0.0, layer_norm_epsilon=1e-6,
layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True, scale_attn_weights=True,
use_cache=True, use_cache=True,
eos_token_id=151643, bf16=False,
apply_residual_connection_post_layernorm=False, fp16=False,
bf16=True, fp32=False,
kv_channels=128, kv_channels=128,
rotary_pct=1.0, rotary_pct=1.0,
rotary_emb_base=10000, rotary_emb_base=10000,
use_dynamic_ntk=False, use_dynamic_ntk=True,
use_logn_attn=False, use_logn_attn=True,
use_flash_attn=True, use_flash_attn="auto",
ffn_hidden_size=22016, intermediate_size=22016,
no_bias=True, no_bias=True,
tie_word_embeddings=False, tie_word_embeddings=False,
**kwargs, **kwargs,
): ):
self.eos_token_id = eos_token_id
super().__init__(eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_embd = n_embd self.hidden_size = hidden_size
self.n_layer = n_layer self.intermediate_size = intermediate_size
self.n_head = n_head self.num_hidden_layers = num_hidden_layers
self.n_inner = n_inner self.num_attention_heads = num_attention_heads
self.embd_pdrop = embd_pdrop self.emb_dropout_prob = emb_dropout_prob
self.attn_pdrop = attn_pdrop self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache self.use_cache = use_cache
self.apply_residual_connection_post_layernorm = ( self.max_position_embeddings = max_position_embeddings
apply_residual_connection_post_layernorm)
self.bf16 = bf16 self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels self.kv_channels = kv_channels
self.rotary_pct = rotary_pct self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn self.use_flash_attn = use_flash_attn
self.ffn_hidden_size = ffn_hidden_size
self.no_bias = no_bias self.no_bias = no_bias
self.tie_word_embeddings = tie_word_embeddings super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)