[Fix] Fix llama4 modelopt weight loading error (#22107)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
jiahanc 2025-08-03 00:50:34 -07:00 committed by GitHub
parent 2ff46b8826
commit 337eb23bcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -906,11 +906,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
format."""
if name.startswith("model."):
if name.startswith("model.") or name.startswith(
"language_model.model."):
renamed = name.replace("model.", "language_model.model.",
1) if name.startswith("model.") else name
# Handle expert scale parameters with flat naming
if "feed_forward.experts." in name and ("_input_scale" in name or
"_weight_scale" in name):
renamed = name.replace("model.", "language_model.model.", 1)
# Map checkpoint naming to vLLM's expected naming
if "down_proj_input_scale" in renamed:
return renamed.replace("down_proj_input_scale",
@ -929,7 +931,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# Handle attention scale parameters
elif "self_attn." in name and (".k_scale" in name
or ".v_scale" in name):
renamed = name.replace("model.", "language_model.model.", 1)
if ".k_proj.k_scale" in renamed:
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
elif ".v_proj.v_scale" in renamed:
@ -937,7 +938,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return renamed
# Standard model.* to language_model.model.* renaming
return name.replace("model.", "language_model.model.", 1)
return renamed
elif name.startswith("lm_head.weight"):
return name.replace("lm_head.weight",