mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 12:25:01 +08:00
[Model] use AutoWeightsLoader for phi, gemma, deepseek (#16088)
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
This commit is contained in:
parent
2fa66ef713
commit
29283eaa7e
@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsPP
|
from .interfaces import SupportsPP
|
||||||
from .utils import (extract_layer_index, is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||||
|
is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
|
|||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip experts that are not assigned to this worker.
|
||||||
|
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip experts that are not assigned to this worker.
|
||||||
|
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class DeepseekForCausalLM(nn.Module, SupportsPP):
|
class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||||
|
|
||||||
@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
loader = AutoWeightsLoader(self)
|
||||||
# (param_name, shard_name, shard_id)
|
return loader.load_weights(weights)
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v"),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
# Skip experts that are not assigned to this worker.
|
|
||||||
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
|
||||||
and name not in params_dict):
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
# Skip experts that are not assigned to this worker.
|
|
||||||
if (("mlp.experts." in name or "mlp.shared_experts." in name)
|
|
||||||
and name not in params_dict):
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
|
|||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||||
|
if shard_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(shard_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
loader = AutoWeightsLoader(
|
||||||
# (param_name, shard_name, shard_id)
|
self,
|
||||||
("qkv_proj", "q_proj", "q"),
|
skip_prefixes=(["lm_head."]
|
||||||
("qkv_proj", "k_proj", "k"),
|
if self.config.tie_word_embeddings else None),
|
||||||
("qkv_proj", "v_proj", "v"),
|
)
|
||||||
("gate_up_proj", "gate_proj", 0),
|
return loader.load_weights(weights)
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
|
||||||
if shard_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(shard_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# lm_head is not used in vllm as it is tied with embed_token.
|
|
||||||
# To prevent errors, skip loading lm_head.weight.
|
|
||||||
if "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
|
|
||||||
return loaded_params
|
|
||||||
|
|||||||
@ -61,7 +61,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -249,6 +249,49 @@ class PhiModel(nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v")
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# pylint: disable=E1136
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
@ -317,43 +360,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
loader = AutoWeightsLoader(self)
|
||||||
# (param_name, shard_name, shard_id)
|
return loader.load_weights(weights)
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v")
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
# pylint: disable=E1136
|
|
||||||
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user