[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (#15939)
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
commit bf7e3c51ae (parent a35a8a8392)
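This change moves the checkpoint-walking loops out of the *ForCausalLM classes: each inner model (BaiChuanModel, GPTNeoXModel, MPTModel) now exposes its own load_weights, and the causal-LM wrappers delegate to AutoWeightsLoader, which routes each (name, tensor) pair to the sub-module that owns it. Below is a minimal, self-contained sketch of that delegation idea; ToyModel and ToyForCausalLM are hypothetical stand-ins, and this is not vLLM's actual AutoWeightsLoader implementation.

```python
# Minimal sketch of the delegation pattern this commit applies (an assumed
# simplification, not vLLM's real AutoWeightsLoader). The wrapper no longer
# iterates the checkpoint itself: it forwards the (name, tensor) stream to the
# sub-module whose own load_weights knows the per-layer details.
from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical inner model that owns the detailed loading logic."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, tensor in weights:
            params_dict[name].data.copy_(tensor)
            loaded.add(name)
        return loaded


class ToyForCausalLM(nn.Module):
    """Hypothetical wrapper: delegates instead of walking the checkpoint."""

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyModel()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Roughly what AutoWeightsLoader(self).load_weights(weights) achieves:
        # strip the sub-module prefix and hand off to its load_weights.
        sub = [(n[len("model."):], w) for n, w in weights
               if n.startswith("model.")]
        return {"model." + n for n in self.model.load_weights(sub)}


if __name__ == "__main__":
    ckpt = [("model.embed_tokens.weight", torch.zeros(8, 4))]
    print(ToyForCausalLM().load_weights(ckpt))  # {'model.embed_tokens.weight'}
```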
@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
@@ -321,6 +321,45 @@ class BaiChuanModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                               SupportsQuant):
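The stacked_params_mapping in the new BaiChuanModel.load_weights handles checkpoints that store gate_proj and up_proj separately while the model keeps one fused gate_up_proj parameter; the shard id picks which half of the fused tensor to fill. A toy illustration of that mapping follows, with made-up shapes and a simplified shard loader (the real weight_loader, on vLLM's MergedColumnParallelLinear, also handles tensor-parallel slicing, which is omitted here).

```python
# Toy illustration of stacked_params_mapping with hypothetical shapes.
import torch

hidden, inter = 4, 6
# Fused parameter: rows [0, inter) hold gate_proj, rows [inter, 2*inter) up_proj.
gate_up_proj = torch.empty(2 * inter, hidden)


def toy_shard_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                     shard_id: int) -> None:
    # Copy the checkpoint tensor into the shard selected by shard_id.
    param[shard_id * inter:(shard_id + 1) * inter].copy_(loaded_weight)


checkpoint = {
    "mlp.gate_proj.weight": torch.ones(inter, hidden),
    "mlp.up_proj.weight": 2 * torch.ones(inter, hidden),
}
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

for name, weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            toy_shard_loader(gate_up_proj, weight, shard_id)
            break

print(gate_up_proj[0, 0].item(), gate_up_proj[inter, 0].item())  # 1.0 2.0
```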
@@ -353,6 +392,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
@@ -393,53 +433,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights.
-                # Refer to:
-                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
-                # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size. This is suggested by
-                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
-                is_baichuan2 = self.config.vocab_size == 125696
-                if is_baichuan2:
-                    loaded_weight = torch.nn.functional.normalize(
-                        loaded_weight)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+    def lm_head_weight_loader(self, param: nn.Parameter,
+                              loaded_weight: torch.Tensor):
+        # Unlike Baichuan, Baichuan2 normalizes the head weights.
+        # Refer to:
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+        # Distinguish between Baichuan and Baichuan2 by checking the
+        # vocab size. This is suggested by
+        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+        is_baichuan2 = self.config.vocab_size == 125696
+        if is_baichuan2:
+            loaded_weight = torch.nn.functional.normalize(loaded_weight)
+
+        default_weight_loader(param, loaded_weight)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
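With the generic loop gone, the Baichuan2-specific head normalization survives as a per-parameter hook: __init__ now sets self.lm_head.weight.weight_loader = self.lm_head_weight_loader, so the special handling still runs when the lm_head weight arrives through the generic loading path. The snippet below only checks what that torch.nn.functional.normalize call does; with default arguments it L2-normalizes along dim=1, i.e. each row of the (vocab_size, hidden_size) matrix, and the shapes here are made up for illustration.

```python
# What the Baichuan2 branch does to the head weight: F.normalize with default
# arguments rescales each row (dim=1) to unit L2 norm. Shapes are hypothetical.
import torch
import torch.nn.functional as F

lm_head_weight = torch.randn(4, 8)          # stand-in for (vocab_size, hidden_size)
normalized = F.normalize(lm_head_weight)    # same call as in lm_head_weight_loader
print(normalized.norm(dim=1))               # every row ends up with norm ~1.0
```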
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -241,6 +241,45 @@ class GPTNeoXModel(nn.Module):
         hidden_states = self.final_layer_norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                    or "rotary_emb.inv_freq" in name):
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using OpenRLHF may include
+                # these tensors in the checkpoint. Skip them.
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTNeoXForCausalLM(nn.Module, SupportsPP):
@@ -297,39 +336,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if ("attention.bias" in name or "attention.masked_bias" in name
-                    or "rotary_emb.inv_freq" in name):
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using OpenRLHF may include
-                # these tensors in the checkpoint. Skip them.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
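Both GPT-NeoX hunks keep the fused-QKV re-layout described in the NOTE comment: the checkpoint orders the output dimension per head as (q, k, v) repeated num_heads times, while the fused QKV parameter expects all query heads first, then all keys, then all values. The standalone check below replays the view/transpose/reshape from the diff with made-up sizes (2 heads, head_size 3, hidden 6); the sizes and variable names are illustrative only.

```python
# Standalone replay of the QKV weight conversion above, with hypothetical
# sizes. Checkpoint layout along the output dim:
#   [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v]
# Converted layout:
#   [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v]
import torch

num_heads, head_size, hidden = 2, 3, 6
output_dim = 0  # output dimension of the weight, as queried via getattr above

w = torch.arange(num_heads * 3 * head_size * hidden, dtype=torch.float32)
w = w.view(num_heads * 3 * head_size, hidden)

shape = w.shape
converted = w.view(shape[:output_dim] + (num_heads, 3, -1) +
                   shape[output_dim + 1:])
converted = converted.transpose(output_dim, output_dim + 1)
converted = converted.reshape(shape)

# Head 0's query rows stay first; head 1's query rows move up right after them.
print(torch.equal(converted[:head_size], w[:head_size]))            # True
print(torch.equal(converted[head_size:2 * head_size],
                  w[3 * head_size:4 * head_size]))                   # True
```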
@@ -27,7 +27,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -266,6 +266,23 @@ class MPTModel(nn.Module):
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MPTForCausalLM(nn.Module, SupportsPP):
@@ -318,17 +335,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
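One detail carried over into MPTModel.load_weights is dict(self.named_parameters(remove_duplicate=False)): when two module attributes share one tensor (e.g. tied embedding and output weights), the default named_parameters() reports the tensor under a single name, so the other checkpoint name would be missing from params_dict and raise a KeyError. A tiny illustration with a hypothetical tied-embedding module, not MPT itself:

```python
# Why remove_duplicate=False matters when parameters are shared (hypothetical
# tied-embedding module): by default, named_parameters() yields a shared
# tensor under one name only.
import torch.nn as nn


class Tied(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.wte = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.wte.weight  # tie the two attributes


m = Tied()
print(sorted(n for n, _ in m.named_parameters()))
# ['wte.weight']
print(sorted(n for n, _ in m.named_parameters(remove_duplicate=False)))
# ['lm_head.weight', 'wte.weight']
```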