mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 08:57:54 +08:00
[Model] use AutoWeightsLoader for BigCode, GPT-J (#16823)
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
This commit is contained in:
parent
26507f8973
commit
87e067de41
@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers)
|
make_empty_intermediate_tensors_factory, make_layers)
|
||||||
|
|
||||||
|
|
||||||
@ -244,6 +244,30 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if ".attn.bias" in name:
|
||||||
|
# Skip attention mask.
|
||||||
|
# NOTE: "c_attn.bias" should not be skipped.
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
|
||||||
|
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
|
||||||
|
weight_loader(param, loaded_weight, 'q')
|
||||||
|
weight_loader(param, loaded_weight, 'k')
|
||||||
|
weight_loader(param, loaded_weight, 'v')
|
||||||
|
else:
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
packed_modules_mapping = {"c_attn": ["c_attn"]}
|
packed_modules_mapping = {"c_attn": ["c_attn"]}
|
||||||
@ -315,26 +339,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
loader = AutoWeightsLoader(
|
||||||
loaded_params: Set[str] = set()
|
self,
|
||||||
for name, loaded_weight in weights:
|
skip_prefixes=(["lm_head."]),
|
||||||
if "lm_head.weight" in name:
|
)
|
||||||
continue
|
return loader.load_weights(weights)
|
||||||
if ".attn.bias" in name:
|
|
||||||
# Skip attention mask.
|
|
||||||
# NOTE: "c_attn.bias" should not be skipped.
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
|
|
||||||
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
|
|
||||||
weight_loader(param, loaded_weight, 'q')
|
|
||||||
weight_loader(param, loaded_weight, 'k')
|
|
||||||
weight_loader(param, loaded_weight, 'v')
|
|
||||||
else:
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsPP
|
from .interfaces import SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -188,6 +188,7 @@ class GPTJModel(nn.Module):
|
|||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
self.embed_dim = config.n_embd
|
self.embed_dim = config.n_embd
|
||||||
self.wte = VocabParallelEmbedding(
|
self.wte = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
@ -228,6 +229,63 @@ class GPTJModel(nn.Module):
|
|||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (self.quant_config is not None and
|
||||||
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
|
# Loading kv cache quantization scales
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||||
|
loaded_weight[0])
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(scale_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class GPTJForCausalLM(nn.Module, SupportsPP):
|
class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||||
|
|
||||||
@ -285,57 +343,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
loader = AutoWeightsLoader(self)
|
||||||
# (param_name, shard_name, shard_id)
|
return loader.load_weights(weights)
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v"),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (self.quant_config is not None and
|
|
||||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
|
||||||
# Loading kv cache quantization scales
|
|
||||||
param = params_dict[scale_name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
|
||||||
loaded_weight[0])
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(scale_name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
|
||||||
if name is None:
|
|
||||||
continue
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
Loading…
x
Reference in New Issue
Block a user