[Bugfix] fix modelopt exclude_modules name mapping (#24178)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
commit 08abfa78ec
parent 2bef2d1405
Author: tomeras91
Date:   2025-09-10 20:20:46 +03:00 (committed by GitHub)
3 changed files with 58 additions and 37 deletions
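
For context, here is a minimal standalone sketch of the mismatch this commit addresses (illustrative names, not vLLM code): ModelOpt's hf_quant_config.json lists exclude_modules using Hugging Face module names, while vLLM instantiates the same layers under its own prefixes, so the exclusion check can silently miss unless the list is remapped.

```python
# Illustrative sketch only (not vLLM code): exclude_modules written against
# Hugging Face names vs. the prefix vLLM assigns to the same layer.
hf_exclude = ["backbone.layers.0.mixer.in_proj", "lm_head"]
vllm_prefix = "model.layers.0.mixer.in_proj"  # name vLLM gives the layer

def remap(name: str) -> str:
    # Toy stand-in for the hf-to-vLLM remap this commit applies to the list.
    return "model." + name[len("backbone."):] if name.startswith("backbone.") else name

print(any(vllm_prefix.startswith(m) for m in hf_exclude))         # False: exclusion silently missed
print(any(vllm_prefix.startswith(remap(m)) for m in hf_exclude))  # True: layer stays unquantized
```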

File: vllm/model_executor/layers/mamba/mamba_mixer2.py

@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# - because in_proj is a concatenation of 3 weights, we
@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
bias=use_bias,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = Mixer2RMSNormGated(intermediate_size,
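
The three hunks above only thread a prefix into the mixer's sub-layers. A rough sketch of what that buys (layer names here are assumed for illustration): each sub-layer now carries a full dotted name that the quantization config can match against exclude_modules.

```python
# Rough illustration (names assumed, not taken from vLLM internals): every
# sub-layer is created with prefix=f"{prefix}.<name>", giving it a stable
# dotted name for exclusion matching.
def mixer_layer_names(prefix: str) -> dict[str, str]:
    return {sub: f"{prefix}.{sub}" for sub in ("conv1d", "in_proj", "out_proj")}

print(mixer_layer_names("model.layers.3.mixer"))
# {'conv1d': 'model.layers.3.mixer.conv1d',
#  'in_proj': 'model.layers.3.mixer.in_proj',
#  'out_proj': 'model.layers.3.mixer.out_proj'}
```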

File: vllm/model_executor/layers/quantization/modelopt.py

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from torch.nn import Module
@@ -45,6 +45,9 @@ from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
has_flashinfer_moe)
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
QUANT_ALGOS = ["FP8", "NVFP4"]
@@ -63,7 +66,7 @@ class ModelOptFp8Config(QuantizationConfig):
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
self.exclude_modules = exclude_modules or []
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")
@@ -84,6 +87,11 @@ class ModelOptFp8Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.exclude_modules is not None:
self.exclude_modules = hf_to_vllm_mapper.apply_list(
self.exclude_modules)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -170,7 +178,9 @@ class ModelOptFp8Config(QuantizationConfig):
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.is_layer_excluded(prefix):
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix)):
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
@@ -615,6 +625,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.exclude_modules is not None:
self.exclude_modules = hf_to_vllm_mapper.apply_list(
self.exclude_modules)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -763,7 +778,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if (is_layer_skipped(prefix, self.exclude_modules)
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix, self.exclude_modules)):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
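
A simplified, toy re-implementation (not the real is_layer_skipped) of why packed_modules_mapping is now passed to the skip check: checkpoints list the unfused projections (q_proj/k_proj/v_proj), while vLLM quantizes the fused qkv_proj, so the fused prefix has to be expanded before matching. The module paths below are made up for illustration.

```python
# Toy version of the skip check; the real is_layer_skipped in vLLM's
# quantization utils differs in detail.
def is_layer_skipped_toy(prefix: str, exclude: list[str],
                         packed: dict[str, list[str]]) -> bool:
    last = prefix.rsplit(".", 1)[-1]
    # Expand a fused module (e.g. "qkv_proj") into the shard names the
    # checkpoint's exclude list may actually use ("q_proj", "k_proj", ...).
    candidates = [prefix] + [prefix[:-len(last)] + shard
                             for shard in packed.get(last, [])]
    return any(c in exclude for c in candidates)

packed = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
exclude = ["model.layers.0.self_attn.q_proj",
           "model.layers.0.self_attn.k_proj",
           "model.layers.0.self_attn.v_proj"]
print(is_layer_skipped_toy("model.layers.0.self_attn.qkv_proj", exclude, packed))  # True
print(is_layer_skipped_toy("model.layers.1.self_attn.qkv_proj", exclude, packed))  # False
```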

File: vllm/model_executor/models/nemotron_h.py

@@ -44,15 +44,16 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP,
SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig
@@ -426,38 +427,36 @@ class NemotronHModel(nn.Module):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
attb_params_mapping = {
"q_proj": "q",
"k_proj": "k",
"v_proj": "v",
}
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "embeddings" in name:
name = name.replace("embeddings", "embed_tokens")
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
loaded_weight = loaded_weight.to(torch.float32)
# load stacked params
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if "D" in name:
loaded_weight = loaded_weight.to(torch.float32)
if "dt_bias" in name:
loaded_weight = loaded_weight.to(torch.float32)
# load attn params
if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
weight_name = next(proj
for proj in ["q_proj", "k_proj", "v_proj"]
if proj in name)
name = name.replace(weight_name, "qkv_proj")
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight,
attb_params_mapping[weight_name])
weight_loader(param, loaded_weight, shard_id)
break
# load other params
else:
param = params_dict[name]
@@ -471,6 +470,14 @@ class NemotronHModel(nn.Module):
class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"backbone": "model"},
orig_to_new_substr={
"A_log": "A",
"embeddings": "embed_tokens"
},
)
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -622,10 +629,5 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
# update name in weights before passing to loader
updated_weights = []
for name, loaded_weight in weights:
name = name.replace("backbone", "model")
updated_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(self)
return loader.load_weights(updated_weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
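
To make the net effect concrete, here is a toy rendition of the renaming the new hf_to_vllm_mapper performs on checkpoint names (the real WeightsMapper is more general; the example names are illustrative).

```python
# Toy rendition of WeightsMapper(orig_to_new_prefix={"backbone": "model"},
#     orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"});
# the real class is more general than this.
def map_name(name: str) -> str:
    if name.startswith("backbone."):
        name = "model." + name[len("backbone."):]
    return name.replace("A_log", "A").replace("embeddings", "embed_tokens")

for hf_name in ("backbone.embeddings.weight",
                "backbone.layers.0.mixer.A_log",
                "backbone.layers.1.mixer.in_proj.weight"):
    print(hf_name, "->", map_name(hf_name))
# backbone.embeddings.weight -> model.embed_tokens.weight
# backbone.layers.0.mixer.A_log -> model.layers.0.mixer.A
# backbone.layers.1.mixer.in_proj.weight -> model.layers.1.mixer.in_proj.weight
```

Because apply_vllm_mapper passes the ModelOpt exclude_modules through this same mapper, an entry such as backbone.layers.0.mixer.in_proj now matches the prefix vLLM actually assigns to that layer.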