From 08abfa78ec2fca59889045bb5aec41e383dac9bd Mon Sep 17 00:00:00 2001
From: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Date: Wed, 10 Sep 2025 20:20:46 +0300
Subject: [PATCH] [Bugfix] fix modelopt exclude_modules name mapping (#24178)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Cyrus Leung
---
 .../layers/mamba/mamba_mixer2.py          |  3 +
 .../layers/quantization/modelopt.py       | 24 +++++--
 vllm/model_executor/models/nemotron_h.py  | 68 ++++++++++---------
 3 files changed, 58 insertions(+), 37 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index bb3fdd38dbef3..04ebdbca85e5d 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=self.conv_dim,
             bias=use_conv_bias,
             quant_config=None,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=intermediate_size + self.conv_dim + self.num_heads,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.in_proj",
         )

         # - because in_proj is a concatenation of 3 weights, we
@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
             bias=use_bias,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         self.norm = Mixer2RMSNormGated(intermediate_size,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e140807879177..9b99931e7b43f 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
 from torch.nn import Module
@@ -45,6 +45,9 @@ from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
                                    has_flashinfer_moe)

+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)

 QUANT_ALGOS = ["FP8", "NVFP4"]
@@ -63,7 +66,7 @@ class ModelOptFp8Config(QuantizationConfig):
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
+        self.exclude_modules = exclude_modules or []
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                            " the format is experimental and could change.")
@@ -84,6 +87,11 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -170,7 +178,9 @@ class ModelOptFp8Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
+                    or self.is_layer_excluded(prefix)):
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -615,6 +625,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -763,7 +778,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules)
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
                     or self.is_layer_excluded(prefix, self.exclude_modules)):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 8a563288cb4d6..da8628df1fe57 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -44,15 +44,16 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
+    make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronHConfig
@@ -426,38 +427,36 @@ class NemotronHModel(nn.Module):

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        attb_params_mapping = {
-            "q_proj": "q",
-            "k_proj": "k",
-            "v_proj": "v",
-        }
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]

         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "embeddings" in name:
-                name = name.replace("embeddings", "embed_tokens")
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue

-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-                loaded_weight = loaded_weight.to(torch.float32)
+            # load stacked params
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue

-            if "D" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "dt_bias" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            # load attn params
-            if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
-                weight_name = next(proj
-                                   for proj in ["q_proj", "k_proj", "v_proj"]
-                                   if proj in name)
-                name = name.replace(weight_name, "qkv_proj")
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight,
-                              attb_params_mapping[weight_name])
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            # load other params
             else:
                 param = params_dict[name]

@@ -471,6 +470,14 @@ class NemotronHModel(nn.Module):

 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                            IsHybrid, SupportsQuant):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"backbone": "model"},
+        orig_to_new_substr={
+            "A_log": "A",
+            "embeddings": "embed_tokens"
+        },
+    )
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -622,10 +629,5 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        # update name in weights before passing to loader
-        updated_weights = []
-        for name, loaded_weight in weights:
-            name = name.replace("backbone", "model")
-            updated_weights.append((name, loaded_weight))
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(updated_weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)