[Quantization] Add prefix for Command A quantized model (#17017)

Chen Xia 2025-04-23 17:32:40 -07:00 committed by GitHub
parent b07d741661
commit 6b2427f995


@@ -89,6 +89,7 @@ class CohereMLP(nn.Module):
         self,
         config: CohereConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -99,12 +100,14 @@ class CohereMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = SiluAndMul()
@@ -158,12 +161,14 @@ class CohereAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,7 +249,9 @@ class CohereDecoderLayer(nn.Module):
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
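
Before this change, CohereMLP created its gate_up_proj and down_proj layers without a prefix, so those sub-layers could not report a fully qualified name (e.g. model.layers.0.mlp.gate_up_proj) to the quantization config and fell back to the empty default. The toy sketch below illustrates why that matters for per-layer quantization decisions; ToyQuantConfig, ToyLinear, and ToyMLP are hypothetical stand-ins for illustration only, not vLLM classes.

# Hypothetical sketch (not vLLM code): why sub-layers need a fully
# qualified prefix when a quantization config makes per-layer decisions.
from typing import Optional


class ToyQuantConfig:
    """Skips quantization for any layer whose full name matches a pattern."""

    def __init__(self, ignored_prefixes: list[str]):
        self.ignored_prefixes = ignored_prefixes

    def is_quantized(self, prefix: str) -> bool:
        return not any(prefix.startswith(p) for p in self.ignored_prefixes)


class ToyLinear:
    """Stand-in for ColumnParallelLinear / RowParallelLinear."""

    def __init__(self,
                 quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = ""):
        self.prefix = prefix
        # With prefix="" (the old behaviour of CohereMLP's sub-layers),
        # every layer looks identical to the config, so per-layer choices
        # (skip one projection, quantize the rest, ...) are impossible.
        self.quantized = (quant_config is not None
                          and quant_config.is_quantized(prefix))


class ToyMLP:
    """Mirrors how CohereMLP now threads its prefix into its projections."""

    def __init__(self,
                 quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = ""):
        self.gate_up_proj = ToyLinear(quant_config,
                                      prefix=f"{prefix}.gate_up_proj")
        self.down_proj = ToyLinear(quant_config,
                                   prefix=f"{prefix}.down_proj")


if __name__ == "__main__":
    cfg = ToyQuantConfig(ignored_prefixes=["model.layers.0.mlp.down_proj"])
    mlp = ToyMLP(cfg, prefix="model.layers.0.mlp")
    print(mlp.gate_up_proj.prefix, mlp.gate_up_proj.quantized)  # ... True
    print(mlp.down_proj.prefix, mlp.down_proj.quantized)        # ... False

In the diff above, the same role is played by the prefix argument that CohereDecoderLayer now forwards into CohereMLP and that CohereMLP in turn passes to its merged/row-parallel projection layers.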