[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2025-11-07 19:38:38 +08:00 committed by GitHub
parent 7bdb42b2f2
commit 1958bda9b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 190 additions and 25 deletions

View File

@ -75,7 +75,11 @@ class ArcticMLP(nn.Module):
)
self.w13 = MergedColumnParallelLinear(
self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config
self.hidden_size,
[self.ffn_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w13",
)
self.w2 = RowParallelLinear(
self.ffn_dim,
@ -83,6 +87,7 @@ class ArcticMLP(nn.Module):
bias=False,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
if config.hidden_act != "silu":
raise ValueError(
@ -297,6 +302,7 @@ class ArcticAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
@ -304,6 +310,7 @@ class ArcticAttention(nn.Module):
bias=False,
reduce_results=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -98,13 +98,22 @@ class BaiChuanMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
@ -152,12 +161,14 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.W_pack",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# Create the alibi slopes and slice them.
if self.position_embedding == "ALIBI":
@ -235,6 +246,7 @@ class BaiChuanDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(

View File

@ -60,6 +60,7 @@ class BambaMLP(nn.Module):
config: BambaConfig,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -67,12 +68,14 @@ class BambaMLP(nn.Module):
output_sizes=[config.intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if config.hidden_act != "silu":
raise ValueError(
@ -118,7 +121,9 @@ class BambaMixerDecoderLayer(nn.Module):
prefix=f"{prefix}.mixer",
)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.feed_forward = BambaMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -202,12 +207,14 @@ class BambaAttentionDecoderLayer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = Attention(
@ -219,7 +226,9 @@ class BambaAttentionDecoderLayer(nn.Module):
prefix=f"{prefix}.attn",
)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.feed_forward = BambaMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -108,12 +108,14 @@ class BloomAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
# Create the alibi slopes and slice them.
@ -152,6 +154,7 @@ class BloomMLP(nn.Module):
self,
config: BloomConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -159,12 +162,14 @@ class BloomMLP(nn.Module):
hidden_size,
4 * hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.gelu_impl = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -192,7 +197,7 @@ class BloomBlock(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon
)
self.mlp = BloomMLP(config, quant_config)
self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)

View File

@ -227,6 +227,7 @@ class ChameleonMLP(nn.Module):
hidden_act: str,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -234,12 +235,14 @@ class ChameleonMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
@ -299,12 +302,14 @@ class ChameleonAttention(nn.Module):
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
@ -393,6 +398,7 @@ class ChameleonDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@ -462,6 +468,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(

View File

@ -209,12 +209,14 @@ class DbrxAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.Wqkv",
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.rotary_emb = get_rope(
self.head_dim,

View File

@ -82,7 +82,11 @@ class DeepseekMLP(nn.Module):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
@ -90,6 +94,7 @@ class DeepseekMLP(nn.Module):
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
@ -239,6 +244,7 @@ class DeepseekAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
@ -246,6 +252,7 @@ class DeepseekAttention(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -240,6 +240,7 @@ class Dots1Attention(nn.Module):
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
@ -247,6 +248,7 @@ class Dots1Attention(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -137,6 +137,7 @@ class FalconAttention(nn.Module):
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
@ -153,6 +154,7 @@ class FalconAttention(nn.Module):
skip_bias_add=True,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results,
prefix=f"{prefix}.dense",
)
self.use_rotary = config.rotary
@ -227,6 +229,7 @@ class FalconMLP(nn.Module):
self,
config: FalconConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -237,6 +240,7 @@ class FalconMLP(nn.Module):
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.act = get_act_fn("gelu")
self.reduce_row_parallel_results = not (
@ -249,6 +253,7 @@ class FalconMLP(nn.Module):
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -275,7 +280,7 @@ class FalconDecoderLayer(nn.Module):
self.self_attention = FalconAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
)
self.mlp = FalconMLP(config, quant_config)
self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
self.config = config
if not hasattr(config, "num_ln_in_parallel_attn"):

View File

@ -59,6 +59,7 @@ class FalconH1MLP(nn.Module):
config: FalconH1Config,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -66,12 +67,14 @@ class FalconH1MLP(nn.Module):
output_sizes=[config.intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.intermediate_size = config.intermediate_size
@ -365,7 +368,7 @@ class FalconH1ParallelHybrid(nn.Module):
self.attention_in_multiplier = config.attention_in_multiplier
self.attn_out_multiplier = config.attention_out_multiplier
self.feed_forward = FalconH1MLP(config)
self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -66,13 +66,22 @@ class Gemma2MLP(nn.Module):
hidden_act: str,
hidden_activation: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
raise ValueError(
@ -134,12 +143,14 @@ class Gemma2Attention(nn.Module):
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
@ -208,6 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
hidden_act=config.hidden_act,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(

View File

@ -78,12 +78,14 @@ class GPTJAttention(nn.Module):
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
tp_world_size = get_tensor_model_parallel_world_size()
@ -130,6 +132,7 @@ class GPTJMLP(nn.Module):
intermediate_size: int,
config: GPTJConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.n_embd
@ -137,11 +140,13 @@ class GPTJMLP(nn.Module):
hidden_size,
intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc_in",
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc_out",
)
self.act = get_act_fn(config.activation_function)
@ -166,7 +171,7 @@ class GPTJBlock(nn.Module):
self.attn = GPTJAttention(
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,

View File

@ -80,12 +80,14 @@ class GPTNeoXAttention(nn.Module):
self.total_num_heads,
bias=self.bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
@ -125,17 +127,20 @@ class GPTNeoXMLP(nn.Module):
self,
config: GPTNeoXConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
self.act = get_act_fn(config.hidden_act)

View File

@ -107,12 +107,14 @@ class JAISAttention(nn.Module):
total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
tp_rank = get_tensor_model_parallel_rank()
@ -147,6 +149,7 @@ class JAISMLP(nn.Module):
intermediate_size: int,
config: JAISConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -156,6 +159,7 @@ class JAISMLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.c_fc2 = (
ColumnParallelLinear(
@ -163,6 +167,7 @@ class JAISMLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc2",
)
if self.swiglu
else None
@ -172,6 +177,7 @@ class JAISMLP(nn.Module):
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.act = SwiGLUActivation()
@ -206,7 +212,7 @@ class JAISBlock(nn.Module):
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,

View File

@ -220,12 +220,14 @@ class JambaAttentionDecoderLayer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = Attention(

View File

@ -191,13 +191,22 @@ class MiniCPMMLP(nn.Module):
hidden_act: str,
hidden_act_param: float,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act == "silu":
self.act_fn = SiluAndMul()
@ -259,12 +268,14 @@ class MiniCPMAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -96,6 +96,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@ -103,6 +104,7 @@ class MiniCPM3Attention(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@ -110,6 +112,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
)
# O projection.
self.o_proj = RowParallelLinear(
@ -117,6 +120,7 @@ class MiniCPM3Attention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -83,6 +83,7 @@ class MPTAttention(nn.Module):
self.total_num_kv_heads,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.Wqkv",
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
@ -92,6 +93,7 @@ class MPTAttention(nn.Module):
self.d_model,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
tp_world_size = get_tensor_model_parallel_world_size()
@ -152,6 +154,7 @@ class MPTMLP(nn.Module):
self,
config: MptConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.d_model
@ -162,6 +165,7 @@ class MPTMLP(nn.Module):
intermediate_size,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(
@ -169,6 +173,7 @@ class MPTMLP(nn.Module):
hidden_size,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -193,7 +198,7 @@ class MPTBlock(nn.Module):
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")
def forward(
self,

View File

@ -158,6 +158,7 @@ class OlmoeAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.tp_size = tp_size
self.tp_rank = get_tensor_model_parallel_rank()
@ -168,6 +169,7 @@ class OlmoeAttention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(

View File

@ -52,13 +52,22 @@ class OrionMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
@ -116,12 +125,14 @@ class OrionAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
@ -183,6 +194,7 @@ class OrionDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -62,14 +62,23 @@ from .utils import (
class PersimmonMLP(nn.Module):
def __init__(
self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None
self,
config: PersimmonConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size, config.intermediate_size, quant_config=quant_config
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size, config.hidden_size, quant_config=quant_config
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
self.act = get_act_fn(config.hidden_act)
@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.is_qk_layernorm = config.qk_layernorm
@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.mlp = PersimmonMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)

View File

@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
scaling = self.head_size**-0.5
@ -148,7 +150,10 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module):
def __init__(
self, config: PhiConfig, quant_config: QuantizationConfig | None = None
self,
config: PhiConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
config.hidden_size,
n_inner,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.act = get_act_fn(config.hidden_act)
@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
self.self_attn = PhiAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
)
self.mlp = PhiMLP(config, quant_config)
self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,

View File

@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,

View File

@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000

View File

@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.scaling = self.head_dim**-0.5

View File

@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
super().__init__()
self.A = ColumnParallelLinear(
input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
input_dim,
rank,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.A",
)
if isinstance(output_dim, list):
@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
self.total_num_attention_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.attention_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# Even though in Zamba2 weights are shared between attention layers, KV
@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_q_adapter",
)
linear_k_adapter = Zamba2LoRA(
self.attention_hidden_size,
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_k_adapter",
)
linear_v_adapter = Zamba2LoRA(
self.attention_hidden_size,
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_v_adapter",
)
else:
linear_q_adapter = nn.Identity()
@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
2 * [self.intermediate_size], # 2x for gate and input projections
bias=self.config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
self.hidden_size,
bias=self.config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
# Only allow GELU activations
@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
bare_block_idx=bare_block_idx,
num_hybrid_layers=num_hybrid_layers,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
# Initialize layer normalizations
@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.linear",
)
self.mamba_decoder = Zamba2MambaDecoderLayer(
config,