diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index b764a940b174..e4d103f7cab9 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -652,9 +652,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [
+        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+    ]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            if any(mo_scale_name in name
+                   for mo_scale_name in modelopt_scale_names):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}")
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 logger.warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e214c30f5d60..e7c264c04f1a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -404,6 +404,11 @@ class LlamaModel(nn.Module):
                     weight_loader(param, loaded_weight)
                     loaded_params.add(scale_name)
                     continue
+                if "scale" in name:
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
                         continue
@@ -423,10 +428,6 @@ class LlamaModel(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
 
                     if is_pp_missing_parameter(name, self):
                         continue
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index da415cdae96e..fbb3704fa080 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -452,7 +452,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
-
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
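
For reference, below is a minimal standalone sketch (not vLLM code) of what the new ModelOpt branch in maybe_remap_kv_scale_name does. The helper name, the example checkpoint name, and the params_dict contents are assumptions chosen for illustration; only the remapping logic mirrors the hunk above.

from typing import Optional


def remap_kv_scale_name_sketch(name: str, params_dict: dict) -> Optional[str]:
    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                # ModelOpt checkpoints attach the scale to k_proj/v_proj;
                # move it onto the attention module's parameter name.
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            # Only accept the remap if the target parameter actually exists.
            return remapped_name if remapped_name in params_dict else None
    return name


# Assumed example: one ModelOpt FP8 checkpoint entry and the matching
# model parameter name.
checkpoint_name = "model.layers.0.self_attn.k_proj.k_scale"
params_dict = {"model.layers.0.self_attn.attn.k_scale": None}
print(remap_kv_scale_name_sketch(checkpoint_name, params_dict))
# model.layers.0.self_attn.attn.k_scale

Presumably this is also why the llama.py hunk moves the remap ahead of the stacked_params_mapping loop: a ModelOpt name such as ...self_attn.k_proj.k_scale contains "k_proj" and would otherwise be caught first by the k_proj -> qkv_proj fusion mapping.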