From 87d871470de15bcf238b6e8afcbe2d2e14f21fae Mon Sep 17 00:00:00 2001
From: learner0810 <39400425+learner0810@users.noreply.github.com>
Date: Fri, 16 May 2025 22:54:13 +0800
Subject: [PATCH] [Model] Use autoweightloader for dbrx (#18251)

Signed-off-by: learner0810
---
 vllm/model_executor/models/dbrx.py | 100 +++++++++++++++--------------
 1 file changed, 53 insertions(+), 47 deletions(-)

diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index e0b4712cdb47b..f21887f71d857 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -319,6 +319,7 @@ class DbrxModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        self.quant_config = quant_config
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
             config.d_model,
@@ -364,6 +365,55 @@ class DbrxModel(nn.Module):
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        expert_params_mapping = [(
+            "w13" if weight_name in ["w1", "v1"] else "w2",
+            f"mlp.{weight_name}",
+        ) for weight_name in ["w1", "v1", "w2"]]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if name.endswith(("w1", "w2", "v1")):
+                name = name + "_weight"
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name, name)
+                break
+
+            else:
+                if is_pp_missing_parameter(name, self):
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class DbrxForCausalLM(nn.Module, SupportsPP):
 
@@ -417,49 +467,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        expert_params_mapping = [(
-            "w13" if weight_name in ["w1", "v1"] else "w2",
-            f"mlp.{weight_name}",
-        ) for weight_name in ["w1", "v1", "w2"]]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            if name.endswith(("w1", "w2", "v1")):
-                name = name + "_weight"
-            for param_name, weight_name in expert_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, weight_name, name)
-                break
-
-            else:
-                if is_pp_missing_parameter(name, self):
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
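Note on the pattern (illustration, not part of the patch): the rewritten DbrxForCausalLM.load_weights delegates to AutoWeightsLoader, which routes each incoming (name, tensor) pair to the sub-module that owns it and defers to a sub-module's own load_weights when one is defined; that is why the DBRX-specific expert-weight and kv-scale handling moves onto DbrxModel above. The sketch below is a minimal, self-contained stand-in for that delegation idea; SimpleAutoLoader and its internals are hypothetical simplifications, not vLLM's actual AutoWeightsLoader implementation.

# Minimal sketch of the delegation pattern, assuming only torch is available.
# This is NOT vLLM's AutoWeightsLoader; it only mirrors the idea of routing
# weights to sub-modules and deferring to their own load_weights hooks.
from collections.abc import Iterable

import torch
import torch.nn as nn


class SimpleAutoLoader:

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            if "." in name:
                prefix, rest = name.split(".", 1)
                child = getattr(self.module, prefix, None)
                if child is not None and hasattr(child, "load_weights"):
                    # Defer to the sub-module's own loader (the role that
                    # DbrxModel.load_weights plays in this patch) and
                    # re-prefix the names it reports as loaded.
                    loaded |= {f"{prefix}.{n}"
                               for n in child.load_weights([(rest, tensor)])}
                    continue
            if name in params:
                # Default path: copy straight into the matching parameter.
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded

Centralizing the traversal this way is what lets the top-level load_weights shrink to a two-line delegation while the model-specific logic stays next to the model that needs it.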