Fix loading of quantized BigCode models (#22463)

Signed-off-by: Eldar Kurtic <eldar@neuralmagic.com>
commit 10a02535d4 (parent 65552b476b)
Author: Eldar Kurtić, 2025-08-09 08:12:12 +02:00 (committed by GitHub)

--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py

@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
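
For context, maybe_prefix comes from vllm/model_executor/models/utils.py and joins a parent prefix with a child module name, skipping the dot when the parent prefix is empty. A minimal sketch of its behavior, inferred from its uses in this diff rather than copied from the vLLM source:

    def maybe_prefix(prefix: str, name: str) -> str:
        # Join a parent prefix and a module name:
        # ("transformer", "h") -> "transformer.h".
        # With an empty parent prefix, return the bare name so that
        # top-level modules do not pick up a leading dot.
        return name if not prefix else f"{prefix}.{name}"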
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.activation_function)
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
                                         quant_config,
                                         prefix=f"{prefix}.attn")
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(inner_dim,
+                             config,
+                             quant_config,
+                             prefix=f"{prefix}.mlp")

     def forward(
         self,
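
Threaded through every constructor, the prefix gives each quantized linear layer the full dotted name under which its tensors appear in the checkpoint, which is what quantization configs use to resolve per-layer schemes. A sketch of how the names compose for the first transformer block; the exact intermediate names are assumptions based on the GPTBigCode checkpoint layout, not taken from this diff:

    # Hypothetical composition for block 0, following the diff top-down:
    prefix = maybe_prefix("", "transformer")   # "transformer"
    block_prefix = f"{prefix}.h.0"             # per-layer prefix from make_layers
    mlp_prefix = f"{block_prefix}.mlp"         # passed to GPTBigMLP
    assert f"{mlp_prefix}.c_fc" == "transformer.h.0.mlp.c_fc"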
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+            if "c_attn.input_scale" in name:
                 weight_loader(param, loaded_weight, 'q')
                 weight_loader(param, loaded_weight, 'k')
                 weight_loader(param, loaded_weight, 'v')
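
The narrowed condition means only c_attn.input_scale is still replicated across the q/k/v shards of the fused QKV projection, while c_attn.weight_scale now falls through to the default loading path. A minimal sketch of the replication contract, with the weight_loader signature assumed from the call sites above:

    # One activation scale in the checkpoint covers the whole fused c_attn
    # tensor, so the same scalar is registered once per logical shard.
    def broadcast_input_scale(weight_loader, param, loaded_weight):
        for shard_id in ("q", "k", "v"):
            weight_loader(param, loaded_weight, shard_id)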
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
-                                           prefix=prefix)
+                                           prefix=maybe_prefix(
+                                               prefix, "transformer"))
         if self.config.tie_word_embeddings:
             self.lm_head = self.transformer.wte
         else:
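
With the fix in place, a quantized BigCode model should load end to end. The checkpoint name below is a placeholder for any FP8-quantized GPTBigCode checkpoint (e.g. a quantized StarCoder), not a real model ID:

    from vllm import LLM, SamplingParams

    # Placeholder model ID: substitute a real quantized BigCode checkpoint.
    llm = LLM(model="your-org/starcoder-fp8")
    outputs = llm.generate(["def fibonacci(n):"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)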