Fix loading of quantized BigCode models (#22463)
Signed-off-by: Eldar Kurtic <eldar@neuralmagic.com>
commit 10a02535d4
parent 65552b476b

@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
 
         self.c_proj = RowParallelLinear(
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.activation_function)
 
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
                                         quant_config,
                                         prefix=f"{prefix}.attn")
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(inner_dim,
+                             config,
+                             quant_config,
+                             prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+            if "c_attn.input_scale" in name:
                 weight_loader(param, loaded_weight, 'q')
                 weight_loader(param, loaded_weight, 'k')
                 weight_loader(param, loaded_weight, 'v')
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
-                                           prefix=prefix)
+                                           prefix=maybe_prefix(
+                                               prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
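
The common thread in the hunks above is that every quantized linear layer, and the GPTBigCodeModel submodule itself, now receives a dotted prefix. Below is a minimal sketch of how those prefixes presumably compose into parameter names. The maybe_prefix helper is imported from .utils in the first hunk; its body is not part of this diff, so the version shown here is an assumption inferred from the call site maybe_prefix(prefix, "transformer"), and the example layer names are illustrative.

# Sketch only: assumed behaviour of the maybe_prefix helper imported from .utils.
def maybe_prefix(prefix: str, name: str) -> str:
    # Prepend the parent prefix only when one is set, so a top-level module
    # still gets the bare name "transformer".
    return f"{prefix}.{name}" if prefix else name


# With prefixes threaded GPTBigCodeForCausalLM -> GPTBigCodeModel ->
# GPTBigCodeBlock -> GPTBigCodeAttention / GPTBigMLP, the quantized linear
# layers end up with names like these (block index 0 chosen for illustration):
for layer in ("attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"):
    print(maybe_prefix("transformer.h.0", layer))
# transformer.h.0.attn.c_attn
# transformer.h.0.attn.c_proj
# transformer.h.0.mlp.c_fc
# transformer.h.0.mlp.c_proj

Prefixed names of this shape are what a quantization config uses to resolve per-module settings against the checkpoint, which is presumably why the missing prefixes broke loading of quantized BigCode models.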
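
The weight-loading hunk in GPTBigCodeModel narrows the special case for the fused c_attn scales: only c_attn.input_scale is still replicated to the q, k and v shards, while c_attn.weight_scale no longer takes that path. A hedged illustration of the new dispatch in plain Python follows; the checkpoint keys are hypothetical examples and the real dispatch goes through vLLM's per-parameter weight_loader, not the toy function below.

# Illustrative sketch of the loading decision after this fix; not vLLM's API.
def replicate_to_qkv_shards(name: str) -> bool:
    # Only the activation scale of the fused c_attn projection is replicated
    # with explicit 'q'/'k'/'v' shard ids (the three weight_loader calls kept
    # in the diff). c_attn.weight_scale is no longer special-cased and
    # presumably falls through to the generic loading path.
    return "c_attn.input_scale" in name


# Hypothetical checkpoint keys, for illustration only.
for name in ("transformer.h.0.attn.c_attn.input_scale",
             "transformer.h.0.attn.c_attn.weight_scale"):
    route = "replicate to q/k/v" if replicate_to_qkv_shards(name) else "default loader"
    print(name, "->", route)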