diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 661a67bdc0db0..036ded530f97d 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
 
         self.c_proj = RowParallelLinear(
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.activation_function)
 
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
                                         quant_config,
                                         prefix=f"{prefix}.attn")
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(inner_dim,
+                             config,
+                             quant_config,
+                             prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+            if "c_attn.input_scale" in name:
                 weight_loader(param, loaded_weight, 'q')
                 weight_loader(param, loaded_weight, 'k')
                 weight_loader(param, loaded_weight, 'v')
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.quant_config = quant_config
 
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
-                                           prefix=prefix)
+                                           prefix=maybe_prefix(
+                                               prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
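
For reference, a minimal sketch (not vLLM source) of how the threaded `prefix` arguments are expected to compose into fully qualified layer names such as `transformer.h.0.mlp.c_fc`, which is what lets a quantization config target individual linear layers by name. The `maybe_prefix` body and the `h.<i>` per-layer naming are assumptions inferred from this diff and from how `make_layers` is typically used; only the composition behaviour is illustrated.

```python
def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behaviour: join with a dot, dropping it when the parent
    # prefix is empty (as when the top-level model is built with prefix="").
    return name if not prefix else f"{prefix}.{name}"


def block_linear_names(layer_idx: int) -> list[str]:
    # "transformer" comes from maybe_prefix(prefix, "transformer") in
    # GPTBigCodeForCausalLM; "h.<i>" mirrors how make_layers numbers the
    # blocks (an assumption, not shown in this diff).
    block = maybe_prefix(maybe_prefix("", "transformer"), f"h.{layer_idx}")
    attn = f"{block}.attn"  # GPTBigCodeBlock passes prefix=f"{prefix}.attn"
    mlp = f"{block}.mlp"    # ...and prefix=f"{prefix}.mlp"
    return [
        f"{attn}.c_attn",
        f"{attn}.c_proj",
        f"{mlp}.c_fc",
        f"{mlp}.c_proj",
    ]


if __name__ == "__main__":
    # ['transformer.h.0.attn.c_attn', 'transformer.h.0.attn.c_proj',
    #  'transformer.h.0.mlp.c_fc', 'transformer.h.0.mlp.c_proj']
    print(block_linear_names(0))
```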