mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 18:44:28 +08:00
[Fix] Fix llama4 modelopt weight loading error (#22107)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent 2ff46b8826
commit 337eb23bcc
@@ -906,11 +906,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
|
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
|
||||||
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
|
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
|
||||||
format."""
|
format."""
|
||||||
if name.startswith("model."):
|
if name.startswith("model.") or name.startswith(
|
||||||
|
"language_model.model."):
|
||||||
|
renamed = name.replace("model.", "language_model.model.",
|
||||||
|
1) if name.startswith("model.") else name
|
||||||
# Handle expert scale parameters with flat naming
|
# Handle expert scale parameters with flat naming
|
||||||
if "feed_forward.experts." in name and ("_input_scale" in name or
|
if "feed_forward.experts." in name and ("_input_scale" in name or
|
||||||
"_weight_scale" in name):
|
"_weight_scale" in name):
|
||||||
renamed = name.replace("model.", "language_model.model.", 1)
|
|
||||||
# Map checkpoint naming to vLLM's expected naming
|
# Map checkpoint naming to vLLM's expected naming
|
||||||
if "down_proj_input_scale" in renamed:
|
if "down_proj_input_scale" in renamed:
|
||||||
return renamed.replace("down_proj_input_scale",
|
return renamed.replace("down_proj_input_scale",
|
||||||
@@ -929,7 +931,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# Handle attention scale parameters
|
# Handle attention scale parameters
|
||||||
elif "self_attn." in name and (".k_scale" in name
|
elif "self_attn." in name and (".k_scale" in name
|
||||||
or ".v_scale" in name):
|
or ".v_scale" in name):
|
||||||
renamed = name.replace("model.", "language_model.model.", 1)
|
|
||||||
if ".k_proj.k_scale" in renamed:
|
if ".k_proj.k_scale" in renamed:
|
||||||
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
|
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
|
||||||
elif ".v_proj.v_scale" in renamed:
|
elif ".v_proj.v_scale" in renamed:
|
||||||
@@ -937,7 +938,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
return renamed
|
return renamed
|
||||||
|
|
||||||
# Standard model.* to language_model.model.* renaming
|
# Standard model.* to language_model.model.* renaming
|
||||||
return name.replace("model.", "language_model.model.", 1)
|
return renamed
|
||||||
|
|
||||||
elif name.startswith("lm_head.weight"):
|
elif name.startswith("lm_head.weight"):
|
||||||
return name.replace("lm_head.weight",
|
return name.replace("lm_head.weight",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user