[spec decode] Fix MTP inference path for MiMo-7B model (#25136)

Signed-off-by: zixi-qi <qizixi@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
qizixi 2025-09-18 09:12:19 -07:00 committed by GitHub
parent 1c3b1634aa
commit c4cb0af98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 6 deletions

View File

@ -53,7 +53,6 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@ -118,6 +117,11 @@ def main():
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method.endswith("mtp"):
speculative_config = {
"method": args.method,
"num_speculative_tokens": args.num_spec_tokens,
}
else:
raise ValueError(f"unknown method: {args.method}")

View File

@ -31,7 +31,7 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp"]
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
@config

View File

@ -241,6 +241,15 @@ class MiMoMTP(nn.Module):
def map_model_name_to_mtp_param_name(self, name: str) -> str:
import regex as re
# add mtp_start_layer_idx (config.num_hidden_layers) to the layer index
pattern = r"(model\.mtp_layers\.)(\d+)(\.)"
match = re.match(pattern, name)
if match:
original_num = int(match.group(2))
new_num = original_num + self.config.num_hidden_layers
name = name.replace(match.group(), f"{match.group(1)}{new_num}.")
# check for early return
name_without_prefix = [
"token_layernorm", "hidden_layernorm", "input_proj",
"final_layernorm"
@ -248,10 +257,11 @@ class MiMoMTP(nn.Module):
for sub_name in name_without_prefix:
if sub_name in name:
return name
pattern = r"model.mtp_layers.(\d+)."
group = re.match(pattern, name)
if group is not None:
name = name.replace(group.group(), group.group() + "mtp_block.")
# add mtp_block
pattern = r"(model\.mtp_layers\.\d+\.)"
match = re.match(pattern, name)
if match:
name = name.replace(match.group(), match.group() + "mtp_block.")
return name
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: