Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:06:10 +08:00)
[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)
Signed-off-by: MengqingCao <cmq0113@163.com>
commit 1958bda9b4
parent 7bdb42b2f2
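Every hunk in this commit applies the same pattern: the module that builds a ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, or ReplicatedLinear now passes its own dotted module name down as prefix, so each linear layer knows its fully qualified name (for example model.layers.0.mlp.gate_up_proj). That name is typically what per-layer machinery, such as quantization overrides keyed on layer names, matches against. The sketch below is illustrative only (ToyQuantConfig, ToyLinear, and ToyMLP are made-up stand-ins, not vLLM classes); it shows the mechanics of threading a prefix through nested modules.

import torch.nn as nn


class ToyQuantConfig:
    """Pretend per-layer policy keyed on the layer's dotted name (illustrative)."""

    def __init__(self, skip_prefixes: list[str]):
        self.skip_prefixes = skip_prefixes

    def method_for(self, prefix: str) -> str:
        # Layers whose fully qualified name matches a skip rule stay unquantized.
        if any(prefix.startswith(p) for p in self.skip_prefixes):
            return "unquantized"
        return "int8"


class ToyLinear(nn.Linear):
    def __init__(self, in_f: int, out_f: int, quant_config: ToyQuantConfig, prefix: str = ""):
        super().__init__(in_f, out_f, bias=False)
        # The leaf layer can only make a per-layer decision because it was
        # handed its own fully qualified name.
        self.quant_method = quant_config.method_for(prefix)


class ToyMLP(nn.Module):
    def __init__(self, hidden: int, inter: int, quant_config: ToyQuantConfig, prefix: str = ""):
        super().__init__()
        # Same shape as the diff below: append the attribute name to the parent's prefix.
        self.gate_up_proj = ToyLinear(hidden, 2 * inter, quant_config, prefix=f"{prefix}.gate_up_proj")
        self.down_proj = ToyLinear(inter, hidden, quant_config, prefix=f"{prefix}.down_proj")


cfg = ToyQuantConfig(skip_prefixes=["model.layers.0"])
mlp0 = ToyMLP(16, 64, cfg, prefix="model.layers.0.mlp")
mlp1 = ToyMLP(16, 64, cfg, prefix="model.layers.1.mlp")
print(mlp0.gate_up_proj.quant_method)  # unquantized: its name matched a skip rule
print(mlp1.down_proj.quant_method)     # int8

With the prefix wired through, a rule such as skip_prefixes=["model.layers.0"] can single out one specific layer instance even though every decoder layer is built from the same class; without it, the leaf layers would be indistinguishable from one another.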
@@ -75,7 +75,11 @@ class ArcticMLP(nn.Module):
         )

         self.w13 = MergedColumnParallelLinear(
-            self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config
+            self.hidden_size,
+            [self.ffn_dim] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             self.ffn_dim,
@@ -83,6 +87,7 @@ class ArcticMLP(nn.Module):
             bias=False,
             reduce_results=reduce_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -297,6 +302,7 @@ class ArcticAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -304,6 +310,7 @@ class ArcticAttention(nn.Module):
             bias=False,
             reduce_results=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -98,13 +98,22 @@ class BaiChuanMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -152,12 +161,14 @@ class BaiChuanAttention(nn.Module):
             self.total_num_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.W_pack",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         # Create the alibi slopes and slice them.
         if self.position_embedding == "ALIBI":
@@ -235,6 +246,7 @@ class BaiChuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -60,6 +60,7 @@ class BambaMLP(nn.Module):
         config: BambaConfig,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -67,12 +68,14 @@ class BambaMLP(nn.Module):
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -118,7 +121,9 @@ class BambaMixerDecoderLayer(nn.Module):
             prefix=f"{prefix}.mixer",
         )

-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -202,12 +207,14 @@ class BambaAttentionDecoderLayer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.attn = Attention(
@@ -219,7 +226,9 @@ class BambaAttentionDecoderLayer(nn.Module):
             prefix=f"{prefix}.attn",
         )

-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -108,12 +108,14 @@ class BloomAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         # Create the alibi slopes and slice them.
@@ -152,6 +154,7 @@ class BloomMLP(nn.Module):
         self,
         config: BloomConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -159,12 +162,14 @@ class BloomMLP(nn.Module):
             hidden_size,
             4 * hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -192,7 +197,7 @@ class BloomBlock(nn.Module):
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon
         )
-        self.mlp = BloomMLP(config, quant_config)
+        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.apply_residual_connection_post_layernorm = (
             config.apply_residual_connection_post_layernorm
         )
@@ -227,6 +227,7 @@ class ChameleonMLP(nn.Module):
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -234,12 +235,14 @@ class ChameleonMLP(nn.Module):
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -299,12 +302,14 @@ class ChameleonAttention(nn.Module):
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
         self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
@@ -393,6 +398,7 @@ class ChameleonDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -462,6 +468,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -209,12 +209,14 @@ class DbrxAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -82,7 +82,11 @@ class DeepseekMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -90,6 +94,7 @@ class DeepseekMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -239,6 +244,7 @@ class DeepseekAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )

         self.o_proj = RowParallelLinear(
@@ -246,6 +252,7 @@ class DeepseekAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -240,6 +240,7 @@ class Dots1Attention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )

         self.o_proj = RowParallelLinear(
@@ -247,6 +248,7 @@ class Dots1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -137,6 +137,7 @@ class FalconAttention(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -153,6 +154,7 @@ class FalconAttention(nn.Module):
             skip_bias_add=True,
             quant_config=quant_config,
             reduce_results=self.reduce_row_parallel_results,
+            prefix=f"{prefix}.dense",
         )

         self.use_rotary = config.rotary
@@ -227,6 +229,7 @@ class FalconMLP(nn.Module):
         self,
         config: FalconConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -237,6 +240,7 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (
@@ -249,6 +253,7 @@ class FalconMLP(nn.Module):
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -275,7 +280,7 @@ class FalconDecoderLayer(nn.Module):
         self.self_attention = FalconAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
         )
-        self.mlp = FalconMLP(config, quant_config)
+        self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.config = config

         if not hasattr(config, "num_ln_in_parallel_attn"):
@@ -59,6 +59,7 @@ class FalconH1MLP(nn.Module):
         config: FalconH1Config,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -66,12 +67,14 @@ class FalconH1MLP(nn.Module):
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.intermediate_size = config.intermediate_size
@@ -365,7 +368,7 @@ class FalconH1ParallelHybrid(nn.Module):
         self.attention_in_multiplier = config.attention_in_multiplier
         self.attn_out_multiplier = config.attention_out_multiplier

-        self.feed_forward = FalconH1MLP(config)
+        self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -66,13 +66,22 @@ class Gemma2MLP(nn.Module):
         hidden_act: str,
         hidden_activation: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
@@ -134,12 +143,14 @@ class Gemma2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -208,6 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(
@@ -78,12 +78,14 @@ class GPTJAttention(nn.Module):
             self.total_num_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.out_proj = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -130,6 +132,7 @@ class GPTJMLP(nn.Module):
         intermediate_size: int,
         config: GPTJConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.n_embd
@@ -137,11 +140,13 @@ class GPTJMLP(nn.Module):
             hidden_size,
             intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc_in",
         )
         self.fc_out = RowParallelLinear(
             intermediate_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc_out",
         )
         self.act = get_act_fn(config.activation_function)

@@ -166,7 +171,7 @@ class GPTJBlock(nn.Module):
         self.attn = GPTJAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
-        self.mlp = GPTJMLP(inner_dim, config, quant_config)
+        self.mlp = GPTJMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -80,12 +80,14 @@ class GPTNeoXAttention(nn.Module):
             self.total_num_heads,
             bias=self.bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=self.bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
@@ -125,17 +127,20 @@ class GPTNeoXMLP(nn.Module):
         self,
         config: GPTNeoXConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.dense_4h_to_h = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -107,12 +107,14 @@ class JAISAttention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )

         tp_rank = get_tensor_model_parallel_rank()
@@ -147,6 +149,7 @@ class JAISMLP(nn.Module):
         intermediate_size: int,
         config: JAISConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -156,6 +159,7 @@ class JAISMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_fc2 = (
             ColumnParallelLinear(
@@ -163,6 +167,7 @@ class JAISMLP(nn.Module):
                 intermediate_size,
                 bias=True,
                 quant_config=quant_config,
+                prefix=f"{prefix}.c_fc2",
             )
             if self.swiglu
             else None
@@ -172,6 +177,7 @@ class JAISMLP(nn.Module):
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )

         self.act = SwiGLUActivation()
@@ -206,7 +212,7 @@ class JAISBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = JAISMLP(inner_dim, config, quant_config)
+        self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -220,12 +220,14 @@ class JambaAttentionDecoderLayer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.attn = Attention(
@@ -191,13 +191,22 @@ class MiniCPMMLP(nn.Module):
         hidden_act: str,
         hidden_act_param: float,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act == "silu":
             self.act_fn = SiluAndMul()
@@ -259,12 +268,14 @@ class MiniCPMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -96,6 +96,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * self.qk_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.q_b_proj",
         )

         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -103,6 +104,7 @@ class MiniCPM3Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -110,6 +112,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.kv_b_proj",
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -117,6 +120,7 @@ class MiniCPM3Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -83,6 +83,7 @@ class MPTAttention(nn.Module):
             self.total_num_kv_heads,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -92,6 +93,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -152,6 +154,7 @@ class MPTMLP(nn.Module):
         self,
         config: MptConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -162,6 +165,7 @@ class MPTMLP(nn.Module):
             intermediate_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
@@ -169,6 +173,7 @@ class MPTMLP(nn.Module):
             hidden_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -193,7 +198,7 @@ class MPTBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config, quant_config)
+        self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")

     def forward(
         self,
@@ -158,6 +158,7 @@ class OlmoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.tp_size = tp_size
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -168,6 +169,7 @@ class OlmoeAttention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -52,13 +52,22 @@ class OrionMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -116,12 +125,14 @@ class OrionAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -183,6 +194,7 @@ class OrionDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )

         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -62,14 +62,23 @@ from .utils import (

 class PersimmonMLP(nn.Module):
     def __init__(
-        self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None
+        self,
+        config: PersimmonConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
-            config.hidden_size, config.intermediate_size, quant_config=quant_config
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.dense_4h_to_h = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, quant_config=quant_config
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
         self.is_qk_layernorm = config.qk_layernorm

@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
         )
-        self.mlp = PersimmonMLP(config, quant_config=quant_config)
+        self.mlp = PersimmonMLP(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
         self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         scaling = self.head_size**-0.5
@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):

 class PhiMLP(nn.Module):
     def __init__(
-        self, config: PhiConfig, quant_config: QuantizationConfig | None = None
+        self,
+        config: PhiConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()

@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             n_inner,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             n_inner,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
         self.self_attn = PhiAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
         )
-        self.mlp = PhiMLP(config, quant_config)
+        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
@@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.scaling = self.head_dim**-0.5

@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
         super().__init__()

         self.A = ColumnParallelLinear(
-            input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
+            input_dim,
+            rank,
+            bias=False,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix=f"{prefix}.A",
         )

         if isinstance(output_dim, list):
@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
             self.total_num_attention_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.attention_hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         # Even though in Zamba2 weights are shared between attention layers, KV
@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_q_adapter",
             )
             linear_k_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_k_adapter",
             )
             linear_v_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_v_adapter",
             )
         else:
             linear_q_adapter = nn.Identity()
@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
             2 * [self.intermediate_size],  # 2x for gate and input projections
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )

         self.down_proj = RowParallelLinear(
@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
             self.hidden_size,
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

         # Only allow GELU activations
@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
             bare_block_idx=bare_block_idx,
             num_hybrid_layers=num_hybrid_layers,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )

         # Initialize layer normalizations
@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.linear",
         )
         self.mamba_decoder = Zamba2MambaDecoderLayer(
             config,