diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 074126fa669e..78b186265dd0 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
                 return None
             return remapped_name
 
-    possible_scale_names = [".k_scale", ".v_scale"]
-    modelopt_scale_names = [
-        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+         r".self_attn.attn.\2_scale"),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"
+         ),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
     ]
-    # Also support qkv_proj scale parameters (from stacked parameter processing)
-    qkv_proj_scale_names = [
-        ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
-    ]
-    for scale_name in possible_scale_names:
-        if name.endswith(scale_name):
-            if any(mo_scale_name in name
-                   for mo_scale_name in modelopt_scale_names):
-                remapped_name = name.replace(
-                    f".self_attn.{scale_name[1]}_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            elif any(qkv_scale_name in name
-                     for qkv_scale_name in qkv_proj_scale_names):
-                # Handle qkv_proj scale parameters
-                remapped_name = name.replace(
-                    f".self_attn.qkv_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            else:
-                remapped_name = name.replace(scale_name, f".attn{scale_name}")
-            if remapped_name not in params_dict:
-                logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_name,
-                    name,
-                    remapped_name,
-                    scale_name,
-                )
-                return None
-            return remapped_name
+
+    # Check if name ends with k_scale or v_scale
+    if name.endswith((".k_scale", ".v_scale")):
+        import regex as re
+
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
 
     # If there were no matches, return the untouched param name
     return name
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 0e7507a4570b..e4f0de04e9a1 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -408,9 +408,18 @@ class Qwen2Model(nn.Module):
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 7410589190ba..b2397c115d1d 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
@@ -471,12 +472,21 @@ class Qwen3MoeModel(nn.Module):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 if name not in params_dict:
                     continue
 
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
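For reference, a minimal standalone sketch of how the new pattern table in
weight_utils.py behaves. It runs outside vLLM: it uses the stdlib re module
instead of the regex package imported in the diff (the two are API-compatible
for the search/sub calls used here), and the checkpoint keys in the asserts
are hypothetical examples, not keys taken from a real model.

import re

# Same patterns and precedence as scale_mapping_patterns in the diff.
SCALE_MAPPING_PATTERNS = [
    (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
     r".self_attn.attn.\2_scale"),
    (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    (r"\.([kv])_scale$", r".attn.\1_scale"),
]

def remap(name: str) -> str:
    # The first matching pattern wins, mirroring the precedence order.
    for pattern, replacement in SCALE_MAPPING_PATTERNS:
        if re.search(pattern, name):
            return re.sub(pattern, replacement, name)
    return name

# Hypothetical checkpoint keys -> expected vLLM parameter names.
assert (remap("model.layers.0.self_attn.k_proj.k_scale") ==
        "model.layers.0.self_attn.attn.k_scale")
assert (remap("model.layers.0.self_attn.qkv_proj.v_scale") ==
        "model.layers.0.self_attn.attn.v_scale")
assert remap("model.layers.0.v_scale") == "model.layers.0.attn.v_scale"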
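The loader dispatch added to both qwen2.py and qwen3_moe.py follows one idiom:
a kv-scale parameter registered by the attention layer may carry no custom
weight_loader attribute, and default_weight_loader takes no shard_id, so the
call signature has to branch. Below is a self-contained sketch of that
dispatch with toy tensors; default_weight_loader is redefined here (matching
vLLM's copy-based behavior) only so the snippet runs standalone.

import torch
from torch import nn

def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    # Plain copy, mirroring the behavior of vLLM's default_weight_loader.
    param.data.copy_(loaded_weight)

# A kv-scale scalar typically has no custom weight_loader attribute,
# so getattr falls back to the default.
param = nn.Parameter(torch.zeros(1), requires_grad=False)
loaded_weight = torch.tensor([0.5])
shard_id = "k"  # only consumed by stacked-parameter loaders

weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
    # The default loader accepts no shard_id.
    weight_loader(param, loaded_weight)
else:
    # Stacked parameters (e.g. qkv_proj) install a loader that needs
    # shard_id to place the weight slice correctly.
    weight_loader(param, loaded_weight, shard_id)

assert param.item() == 0.5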