diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 5af232cb6af6a..004e75b204642 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -53,7 +53,6 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -118,6 +117,11 @@ def main(): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method.endswith("mtp"): + speculative_config = { + "method": args.method, + "num_speculative_tokens": args.num_spec_tokens, + } else: raise ValueError(f"unknown method: {args.method}") diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index fca8c28e5c61e..2c861723c3966 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -31,7 +31,7 @@ logger = init_logger(__name__) SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp", "qwen3_next_mtp"] + "ernie_mtp", "qwen3_next_mtp", "mimo_mtp"] @config diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index ac835edc001ea..09194e9f95d0e 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -241,6 +241,15 @@ class MiMoMTP(nn.Module): def map_model_name_to_mtp_param_name(self, name: str) -> str: import regex as re + + # append mtp_start_layer_idx + pattern = r"(model\.mtp_layers\.)(\d+)(\.)" + match = re.match(pattern, name) + if match: + original_num = int(match.group(2)) + new_num = original_num + self.config.num_hidden_layers + name = name.replace(match.group(), f"{match.group(1)}{new_num}.") + # check for early turn name_without_prefix = [ "token_layernorm", "hidden_layernorm", "input_proj", "final_layernorm" @@ -248,10 +257,11 @@ class MiMoMTP(nn.Module): for sub_name in name_without_prefix: if sub_name in name: return name - pattern = r"model.mtp_layers.(\d+)." - group = re.match(pattern, name) - if group is not None: - name = name.replace(group.group(), group.group() + "mtp_block.") + # add mtp_block + pattern = r"(model\.mtp_layers\.\d+\.)" + match = re.match(pattern, name) + if match: + name = name.replace(match.group(), match.group() + "mtp_block.") return name def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: