Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 21:55:38 +08:00)
[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
parent 7bdb42b2f2
commit 1958bda9b4
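The whole change follows one mechanical pattern: each MLP and attention module gains a prefix parameter (defaulting to ""), and every parallel Linear layer it constructs (MergedColumnParallelLinear, ColumnParallelLinear, RowParallelLinear, ReplicatedLinear, and the QKV projections) now receives prefix=f"{prefix}.<attribute_name>", so each leaf layer knows its fully qualified dotted name in the module hierarchy; where a constructor call had to grow, its arguments were also reflowed to one per line. Below is a minimal sketch of how such prefixes compose down the module tree, assuming nothing about vLLM internals: ToyLinear, ToyMLP, and the example prefix string are hypothetical stand-ins for illustration only.

# Minimal sketch of the prefix-threading pattern applied in this commit.
# ToyLinear and ToyMLP are illustrative stand-ins, not the vLLM classes:
# the point is only that each parent passes its own dotted name down, so
# every leaf layer ends up knowing its fully qualified parameter name.


class ToyLinear:
    def __init__(self, in_size: int, out_size: int, *, prefix: str = "") -> None:
        # In vLLM this role is played by the parallel Linear layers; the
        # prefix identifies which weights / per-layer settings belong to
        # this instance.
        self.prefix = prefix
        self.shape = (in_size, out_size)


class ToyMLP:
    def __init__(self, hidden: int, intermediate: int, *, prefix: str = "") -> None:
        # Before this change the leaf layers were constructed without a
        # prefix; after it, every constructor call appends its attribute name.
        self.gate_up_proj = ToyLinear(
            hidden, 2 * intermediate, prefix=f"{prefix}.gate_up_proj"
        )
        self.down_proj = ToyLinear(
            intermediate, hidden, prefix=f"{prefix}.down_proj"
        )


mlp = ToyMLP(4096, 11008, prefix="model.layers.0.mlp")
print(mlp.gate_up_proj.prefix)  # model.layers.0.mlp.gate_up_proj
print(mlp.down_proj.prefix)     # model.layers.0.mlp.down_proj

In vLLM these dotted names typically line up with the checkpoint weight names, which is what makes per-layer handling (for example quantization overrides keyed on layer name) possible.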
@@ -75,7 +75,11 @@ class ArcticMLP(nn.Module):
         )

         self.w13 = MergedColumnParallelLinear(
-            self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config
+            self.hidden_size,
+            [self.ffn_dim] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             self.ffn_dim,
@@ -83,6 +87,7 @@ class ArcticMLP(nn.Module):
             bias=False,
             reduce_results=reduce_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -297,6 +302,7 @@ class ArcticAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -304,6 +310,7 @@ class ArcticAttention(nn.Module):
             bias=False,
             reduce_results=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -98,13 +98,22 @@ class BaiChuanMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -152,12 +161,14 @@ class BaiChuanAttention(nn.Module):
             self.total_num_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.W_pack",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         # Create the alibi slopes and slice them.
         if self.position_embedding == "ALIBI":
@@ -235,6 +246,7 @@ class BaiChuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -60,6 +60,7 @@ class BambaMLP(nn.Module):
         config: BambaConfig,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -67,12 +68,14 @@ class BambaMLP(nn.Module):
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -118,7 +121,9 @@ class BambaMixerDecoderLayer(nn.Module):
             prefix=f"{prefix}.mixer",
         )

-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -202,12 +207,14 @@ class BambaAttentionDecoderLayer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.attn = Attention(
@@ -219,7 +226,9 @@ class BambaAttentionDecoderLayer(nn.Module):
             prefix=f"{prefix}.attn",
         )

-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -108,12 +108,14 @@ class BloomAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         # Create the alibi slopes and slice them.
@@ -152,6 +154,7 @@ class BloomMLP(nn.Module):
         self,
         config: BloomConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -159,12 +162,14 @@ class BloomMLP(nn.Module):
             hidden_size,
             4 * hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -192,7 +197,7 @@ class BloomBlock(nn.Module):
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon
         )
-        self.mlp = BloomMLP(config, quant_config)
+        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.apply_residual_connection_post_layernorm = (
             config.apply_residual_connection_post_layernorm
         )
@@ -227,6 +227,7 @@ class ChameleonMLP(nn.Module):
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -234,12 +235,14 @@ class ChameleonMLP(nn.Module):
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -299,12 +302,14 @@ class ChameleonAttention(nn.Module):
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
         self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
@@ -393,6 +398,7 @@ class ChameleonDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -462,6 +468,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -209,12 +209,14 @@ class DbrxAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -82,7 +82,11 @@ class DeepseekMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -90,6 +94,7 @@ class DeepseekMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -239,6 +244,7 @@ class DeepseekAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )

         self.o_proj = RowParallelLinear(
@@ -246,6 +252,7 @@ class DeepseekAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -240,6 +240,7 @@ class Dots1Attention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )

         self.o_proj = RowParallelLinear(
@@ -247,6 +248,7 @@ class Dots1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -137,6 +137,7 @@ class FalconAttention(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -153,6 +154,7 @@ class FalconAttention(nn.Module):
             skip_bias_add=True,
             quant_config=quant_config,
             reduce_results=self.reduce_row_parallel_results,
+            prefix=f"{prefix}.dense",
         )

         self.use_rotary = config.rotary
@@ -227,6 +229,7 @@ class FalconMLP(nn.Module):
         self,
         config: FalconConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -237,6 +240,7 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (
@@ -249,6 +253,7 @@ class FalconMLP(nn.Module):
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -275,7 +280,7 @@ class FalconDecoderLayer(nn.Module):
         self.self_attention = FalconAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
         )
-        self.mlp = FalconMLP(config, quant_config)
+        self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.config = config

         if not hasattr(config, "num_ln_in_parallel_attn"):
@@ -59,6 +59,7 @@ class FalconH1MLP(nn.Module):
         config: FalconH1Config,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -66,12 +67,14 @@ class FalconH1MLP(nn.Module):
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.intermediate_size = config.intermediate_size
@@ -365,7 +368,7 @@ class FalconH1ParallelHybrid(nn.Module):
         self.attention_in_multiplier = config.attention_in_multiplier
         self.attn_out_multiplier = config.attention_out_multiplier

-        self.feed_forward = FalconH1MLP(config)
+        self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -66,13 +66,22 @@ class Gemma2MLP(nn.Module):
         hidden_act: str,
         hidden_activation: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
@@ -134,12 +143,14 @@ class Gemma2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -208,6 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(
@@ -78,12 +78,14 @@ class GPTJAttention(nn.Module):
             self.total_num_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.out_proj = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -130,6 +132,7 @@ class GPTJMLP(nn.Module):
         intermediate_size: int,
         config: GPTJConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.n_embd
@@ -137,11 +140,13 @@ class GPTJMLP(nn.Module):
             hidden_size,
             intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc_in",
         )
         self.fc_out = RowParallelLinear(
             intermediate_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc_out",
         )
         self.act = get_act_fn(config.activation_function)

@@ -166,7 +171,7 @@ class GPTJBlock(nn.Module):
         self.attn = GPTJAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
-        self.mlp = GPTJMLP(inner_dim, config, quant_config)
+        self.mlp = GPTJMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -80,12 +80,14 @@ class GPTNeoXAttention(nn.Module):
             self.total_num_heads,
             bias=self.bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=self.bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
@@ -125,17 +127,20 @@ class GPTNeoXMLP(nn.Module):
         self,
         config: GPTNeoXConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.dense_4h_to_h = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -107,12 +107,14 @@ class JAISAttention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )

         tp_rank = get_tensor_model_parallel_rank()
@@ -147,6 +149,7 @@ class JAISMLP(nn.Module):
         intermediate_size: int,
         config: JAISConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -156,6 +159,7 @@ class JAISMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_fc2 = (
             ColumnParallelLinear(
@@ -163,6 +167,7 @@ class JAISMLP(nn.Module):
                 intermediate_size,
                 bias=True,
                 quant_config=quant_config,
+                prefix=f"{prefix}.c_fc2",
             )
             if self.swiglu
             else None
@@ -172,6 +177,7 @@ class JAISMLP(nn.Module):
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )

         self.act = SwiGLUActivation()
@@ -206,7 +212,7 @@ class JAISBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = JAISMLP(inner_dim, config, quant_config)
+        self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -220,12 +220,14 @@ class JambaAttentionDecoderLayer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.attn = Attention(
@@ -191,13 +191,22 @@ class MiniCPMMLP(nn.Module):
         hidden_act: str,
         hidden_act_param: float,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act == "silu":
             self.act_fn = SiluAndMul()
@@ -259,12 +268,14 @@ class MiniCPMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -96,6 +96,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * self.qk_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.q_b_proj",
         )

         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -103,6 +104,7 @@ class MiniCPM3Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -110,6 +112,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.kv_b_proj",
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -117,6 +120,7 @@ class MiniCPM3Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -83,6 +83,7 @@ class MPTAttention(nn.Module):
             self.total_num_kv_heads,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -92,6 +93,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         tp_world_size = get_tensor_model_parallel_world_size()
@@ -152,6 +154,7 @@ class MPTMLP(nn.Module):
         self,
         config: MptConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -162,6 +165,7 @@ class MPTMLP(nn.Module):
             intermediate_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
@@ -169,6 +173,7 @@ class MPTMLP(nn.Module):
             hidden_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -193,7 +198,7 @@ class MPTBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config, quant_config)
+        self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")

     def forward(
         self,
@@ -158,6 +158,7 @@ class OlmoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.tp_size = tp_size
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -168,6 +169,7 @@ class OlmoeAttention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -52,13 +52,22 @@ class OrionMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -116,12 +125,14 @@ class OrionAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -183,6 +194,7 @@ class OrionDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )

         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -62,14 +62,23 @@ from .utils import (

 class PersimmonMLP(nn.Module):
     def __init__(
-        self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None
+        self,
+        config: PersimmonConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
-            config.hidden_size, config.intermediate_size, quant_config=quant_config
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.dense_4h_to_h = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, quant_config=quant_config
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
         self.is_qk_layernorm = config.qk_layernorm

@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
         )
-        self.mlp = PersimmonMLP(config, quant_config=quant_config)
+        self.mlp = PersimmonMLP(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
         self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         scaling = self.head_size**-0.5
@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):

 class PhiMLP(nn.Module):
     def __init__(
-        self, config: PhiConfig, quant_config: QuantizationConfig | None = None
+        self,
+        config: PhiConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()

@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             n_inner,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             n_inner,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )
         self.act = get_act_fn(config.hidden_act)

@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
         self.self_attn = PhiAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
         )
-        self.mlp = PhiMLP(config, quant_config)
+        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
@@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.scaling = self.head_dim**-0.5

@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
         super().__init__()

         self.A = ColumnParallelLinear(
-            input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
+            input_dim,
+            rank,
+            bias=False,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix=f"{prefix}.A",
         )

         if isinstance(output_dim, list):
@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
             self.total_num_attention_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.attention_hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         # Even though in Zamba2 weights are shared between attention layers, KV
@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_q_adapter",
             )
             linear_k_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_k_adapter",
             )
             linear_v_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_v_adapter",
             )
         else:
             linear_q_adapter = nn.Identity()
@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
             2 * [self.intermediate_size], # 2x for gate and input projections
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )

         self.down_proj = RowParallelLinear(
@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
             self.hidden_size,
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

         # Only allow GELU activations
@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
             bare_block_idx=bare_block_idx,
             num_hybrid_layers=num_hybrid_layers,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )

         # Initialize layer normalizations
@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.linear",
         )
         self.mamba_decoder = Zamba2MambaDecoderLayer(
             config,