From 242a637aead7d5a60a65232d51d3a091fb918925 Mon Sep 17 00:00:00 2001
From: "rongfu.leng"
Date: Sun, 6 Apr 2025 20:52:01 +0800
Subject: [PATCH] [Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2
 (#16103)

Signed-off-by: rongfu.leng
---
 vllm/model_executor/models/stablelm.py   | 94 +++++++++++++-----------
 vllm/model_executor/models/starcoder2.py | 84 +++++++++++----------
 vllm/model_executor/models/zamba2.py     | 78 ++++++++++----------
 3 files changed, 135 insertions(+), 121 deletions(-)

diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index a15faec547b95..53f520304abc4 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -44,7 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -253,6 +253,45 @@ class StableLMEpochModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class StablelmForCausalLM(nn.Module, SupportsPP):
 
@@ -308,46 +347,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            # Models trained using ColossalAI may include these tensors in
+            # the checkpoint. Skip them.
+            skip_prefixes=[
+                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ],
+        )
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 3d11dfd779210..8b9fb7cb7bc6e 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -45,7 +45,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -256,6 +256,41 @@ class Starcoder2Model(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 
@@ -319,41 +354,12 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            # Models trained using ColossalAI may include these tensors in
+            # the checkpoint. Skip them.
+            skip_prefixes=([
+                "rotary_emb.inv_freq", "lm_head.weight"
+            ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 7e210244f794d..c5330203baca8 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class Zamba2LoRA(nn.Module):
@@ -777,6 +777,37 @@ class Zamba2Model(nn.Module):
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for chkpt_weight_name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in chkpt_weight_name:
+                    continue
+                chkpt_weight_name = chkpt_weight_name.replace(
+                    weight_name, param_name)
+                param = params_dict[chkpt_weight_name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if chkpt_weight_name not in params_dict:
+                    continue
+                param = params_dict[chkpt_weight_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(chkpt_weight_name)
+        return loaded_params
+
 
 class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
     """Zamba2 model with causal language modeling head.
@@ -787,6 +818,12 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
     - Support for model parallelism and quantization
     - Sampling capabilities for text generation
     """
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        "A_log": "A",
+        "0.weight": "A.weight",
+        "1.weight": "B.weight",
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         """Initialize the Zamba2 model for causal language modeling.
@@ -992,40 +1029,5 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        weights_dict = {}
-        for key, loaded_weight in weights:
-            if "A_log" in key:
-                key = key.replace("A_log", "A")
-            elif "adapter_list" in key:
-                key = key.replace("0.weight", "A.weight")
-                key = key.replace("1.weight", "B.weight")
-            weights_dict[key] = loaded_weight
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for chkpt_weight_name, loaded_weight in weights_dict.items():
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in chkpt_weight_name:
-                    continue
-                chkpt_weight_name = chkpt_weight_name.replace(
-                    weight_name, param_name)
-                param = params_dict[chkpt_weight_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                if chkpt_weight_name not in params_dict:
-                    continue
-                param = params_dict[chkpt_weight_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(chkpt_weight_name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)