[Bugfix] Fix MTP weight loading (#21941)

This commit is contained in:
Benjamin Chislett 2025-07-31 16:33:53 -04:00 committed by GitHub
parent 71470bc4af
commit 2dff2e21d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
@@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj")
and name not in params_dict):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue