From 3987e2ae963963b9edb132935deabd16dd5a7468 Mon Sep 17 00:00:00 2001
From: iLeGend <824040212@qq.com>
Date: Fri, 30 May 2025 12:50:10 +0800
Subject: [PATCH] [Model] Use AutoWeightsLoader for mamba2 (#18918)

Signed-off-by: iLeGend <824040212@qq.com>
---
 vllm/model_executor/models/mamba2.py | 43 ++++++++++++++++------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index 858a1633befa..65c6467bcf5f 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -32,7 +32,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -167,6 +167,27 @@ class Mamba2Model(nn.Module):
 
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
                         SupportsV0Only):
@@ -282,21 +303,5 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
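
Note: the patch above moves the hand-written weight-loading loop into Mamba2Model and has Mamba2ForCausalLM delegate to AutoWeightsLoader instead. The snippet below is a minimal, hypothetical sketch of that delegation pattern; the class SketchAutoWeightsLoader and its route-by-prefix logic are illustrative assumptions, not vLLM's actual AutoWeightsLoader implementation. It only shows why Mamba2Model now needs a load_weights of its own: the loader forwards each submodule's weights to that submodule's loader.

# Minimal, hypothetical sketch of the delegation pattern behind the change
# above: the top-level model no longer loads tensors itself; it hands the
# (name, tensor) stream to a loader that routes each entry to the submodule
# that owns it (e.g. Mamba2Model.load_weights). Illustration only, not
# vLLM's actual AutoWeightsLoader.
from collections.abc import Iterable

import torch
import torch.nn as nn


class SketchAutoWeightsLoader:  # hypothetical stand-in for AutoWeightsLoader

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Materialize once so each child can be filtered without exhausting
        # a generator.
        weights = list(weights)
        loaded: set[str] = set()
        for child_name, child in self.module.named_children():
            prefix = f"{child_name}."
            child_weights = [(name[len(prefix):], tensor)
                             for name, tensor in weights
                             if name.startswith(prefix)]
            if hasattr(child, "load_weights"):
                # Delegate to the submodule's own loader (the method this
                # patch adds to Mamba2Model) and re-prefix the names it
                # reports as loaded.
                loaded.update(prefix + n
                              for n in child.load_weights(child_weights))
        return loaded

With a loader along these lines, Mamba2ForCausalLM.load_weights collapses to constructing the loader over self and returning loader.load_weights(weights), exactly as in the patch, while checkpoint-specific handling (the A_log -> A rename, skipping extra GPTQ bias entries, pipeline-parallel missing parameters) stays next to the module that owns those parameters.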