Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-03 12:31:19 +08:00)
[Bugfix] fix modelopt exclude_modules name mapping (#24178)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

This commit is contained in:
  parent 2bef2d1405
  commit 08abfa78ec
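Context, summarized from the diff below: ModelOpt checkpoints list exclude_modules with checkpoint-side module names (for NemotronH, a "backbone.*" prefix), while vLLM builds its layers under remapped names ("model.*"), so the exclusion check could fail to match. The fix threads each model's hf_to_vllm_mapper into the ModelOpt configs through a new apply_vllm_mapper hook and gives the Mamba sub-layers proper prefixes to match against. A toy illustration of the kind of mismatch being addressed, with made-up module names:

# Illustrative only: module names are examples, not taken from a real checkpoint.
exclude_hf = ["backbone.layers.0.mixer.in_proj"]    # checkpoint-side name
layer_prefix = "model.layers.0.mixer.in_proj"       # name vLLM builds the layer under

print(any(e in layer_prefix for e in exclude_hf))    # False -> exclusion misses

exclude_vllm = [e.replace("backbone", "model", 1) for e in exclude_hf]
print(any(e in layer_prefix for e in exclude_vllm))  # True  -> layer stays unquantized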
@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=self.conv_dim,
             bias=use_conv_bias,
             quant_config=None,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
             output_size=intermediate_size + self.conv_dim + self.num_heads,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.in_proj",
         )

         # - because in_proj is a concatenation of 3 weights, we
@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
             bias=use_bias,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         self.norm = Mixer2RMSNormGated(intermediate_size,
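The three added prefix=f"{prefix}.*" arguments give the conv1d, in_proj and out_proj sub-layers of the Mamba block fully qualified names; that prefix is what the quantization config's get_quant_method receives and compares against exclude_modules in the hunks below. A rough sketch of that kind of prefix matching (hypothetical helper, not vLLM's actual implementation):

# Hypothetical sketch: match a fully qualified prefix such as
# "model.layers.0.mixer.in_proj" against an exclude list by dotted components.
def is_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    parts = prefix.split(".")
    for pattern in exclude_modules:
        pat = pattern.split(".")
        # excluded if the pattern occurs as a contiguous run of components
        if any(parts[i:i + len(pat)] == pat
               for i in range(len(parts) - len(pat) + 1)):
            return True
    return False

assert is_excluded("model.layers.0.mixer.in_proj", ["mixer.in_proj"])
assert not is_excluded("model.layers.0.mixer.in_proj", ["mixer.out_proj"])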
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
 from torch.nn import Module
@@ -45,6 +45,9 @@ from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
                                    has_flashinfer_moe)

+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)

 QUANT_ALGOS = ["FP8", "NVFP4"]
@@ -63,7 +66,7 @@ class ModelOptFp8Config(QuantizationConfig):
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
+        self.exclude_modules = exclude_modules or []
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                            " the format is experimental and could change.")
@@ -84,6 +87,11 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
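The new apply_vllm_mapper hook lets a model rewrite the checkpoint-side exclude_modules entries into vLLM's internal module names before any layer is built. A stand-in sketch of the intended effect; ToyMapper only mimics the prefix rules of the real WeightsMapper.apply_list from vllm.model_executor.models.utils:

# Stand-in sketch: rewrite excluded module names the way a prefix rule would.
class ToyMapper:
    def __init__(self, orig_to_new_prefix: dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def apply_list(self, names: list[str]) -> list[str]:
        remapped = []
        for name in names:
            for old, new in self.orig_to_new_prefix.items():
                if name.startswith(old):
                    name = new + name[len(old):]
                    break
            remapped.append(name)
        return remapped

mapper = ToyMapper({"backbone": "model"})
print(mapper.apply_list(["backbone.layers.0.mixer.in_proj", "lm_head"]))
# ['model.layers.0.mixer.in_proj', 'lm_head']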
@@ -170,7 +178,9 @@ class ModelOptFp8Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
+                    or self.is_layer_excluded(prefix)):
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
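Passing packed_modules_mapping into is_layer_skipped matters because checkpoints exclude the unfused projection names (q_proj, k_proj, v_proj) while vLLM builds a single fused qkv_proj; the skip check has to expand the fused name into its shard names before comparing against exclude_modules. A hypothetical sketch of that expansion (illustrative logic, not the real is_layer_skipped from vLLM's quant utils):

# Hypothetical sketch: treat a fused layer as skipped if any of its shard
# names (per packed_modules_mapping) shows up in the exclude list.
def layer_skipped(prefix: str, exclude: list[str],
                  packed_mapping: dict[str, list[str]]) -> bool:
    last = prefix.split(".")[-1]                     # e.g. "qkv_proj"
    candidates = [prefix] + [
        prefix.replace(last, shard) for shard in packed_mapping.get(last, [])
    ]
    return any(pattern in cand for cand in candidates for pattern in exclude)

packed = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(layer_skipped("model.layers.0.attn.qkv_proj",
                    ["model.layers.0.attn.q_proj"], packed))   # True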
@@ -615,6 +625,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -763,7 +778,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules)
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
                     or self.is_layer_excluded(prefix, self.exclude_modules)):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
@@ -44,15 +44,16 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
+    make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronHConfig
@@ -426,38 +427,36 @@ class NemotronHModel(nn.Module):

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        attb_params_mapping = {
-            "q_proj": "q",
-            "k_proj": "k",
-            "v_proj": "v",
-        }
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]

         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "embeddings" in name:
-                name = name.replace("embeddings", "embed_tokens")
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue

-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-                loaded_weight = loaded_weight.to(torch.float32)
+            # load stacked params
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue

-            if "D" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "dt_bias" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            # load attn params
-            if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
-                weight_name = next(proj
-                                   for proj in ["q_proj", "k_proj", "v_proj"]
-                                   if proj in name)
-                name = name.replace(weight_name, "qkv_proj")
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight,
-                              attb_params_mapping[weight_name])
+                weight_loader(param, loaded_weight, shard_id)
+                break
+
             # load other params
             else:
                 param = params_dict[name]
@@ -471,6 +470,14 @@ class NemotronHModel(nn.Module):

 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                            IsHybrid, SupportsQuant):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"backbone": "model"},
+        orig_to_new_substr={
+            "A_log": "A",
+            "embeddings": "embed_tokens"
+        },
+    )
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
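Declaring hf_to_vllm_mapper on the class makes the old in-loop string replacements ("embeddings" -> "embed_tokens", "A_log" -> "A", "backbone" -> "model") declarative, and the same mapper is what the ModelOpt configs can now apply to exclude_modules via apply_vllm_mapper. A toy stand-in showing the renames these rules express (the real WeightsMapper lives in vllm.model_executor.models.utils):

# Toy stand-in for the prefix + substring rename rules declared above.
def remap(name: str) -> str:
    prefix_rules = {"backbone": "model"}
    substr_rules = {"A_log": "A", "embeddings": "embed_tokens"}
    for old, new in prefix_rules.items():
        if name.startswith(old):
            name = new + name[len(old):]
    for old, new in substr_rules.items():
        name = name.replace(old, new)
    return name

print(remap("backbone.embeddings.weight"))     # model.embed_tokens.weight
print(remap("backbone.layers.3.mixer.A_log"))  # model.layers.3.mixer.A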
@@ -622,10 +629,5 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        # update name in weights before passing to loader
-        updated_weights = []
-        for name, loaded_weight in weights:
-            name = name.replace("backbone", "model")
-            updated_weights.append((name, loaded_weight))
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(updated_weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)