mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 01:57:02 +08:00
[spec decode] Fix MTP inference path for MiMo-7B model (#25136)
Signed-off-by: zixi-qi <qizixi@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
1c3b1634aa
commit
c4cb0af98a
@ -53,7 +53,6 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@ -118,6 +117,11 @@ def main():
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method.endswith("mtp"):
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unknown method: {args.method}")
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ logger = init_logger(__name__)
|
||||
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp", "qwen3_next_mtp"]
|
||||
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
|
||||
|
||||
|
||||
@config
|
||||
|
||||
@ -241,6 +241,15 @@ class MiMoMTP(nn.Module):
|
||||
|
||||
def map_model_name_to_mtp_param_name(self, name: str) -> str:
|
||||
import regex as re
|
||||
|
||||
# append mtp_start_layer_idx
|
||||
pattern = r"(model\.mtp_layers\.)(\d+)(\.)"
|
||||
match = re.match(pattern, name)
|
||||
if match:
|
||||
original_num = int(match.group(2))
|
||||
new_num = original_num + self.config.num_hidden_layers
|
||||
name = name.replace(match.group(), f"{match.group(1)}{new_num}.")
|
||||
# check for early turn
|
||||
name_without_prefix = [
|
||||
"token_layernorm", "hidden_layernorm", "input_proj",
|
||||
"final_layernorm"
|
||||
@ -248,10 +257,11 @@ class MiMoMTP(nn.Module):
|
||||
for sub_name in name_without_prefix:
|
||||
if sub_name in name:
|
||||
return name
|
||||
pattern = r"model.mtp_layers.(\d+)."
|
||||
group = re.match(pattern, name)
|
||||
if group is not None:
|
||||
name = name.replace(group.group(), group.group() + "mtp_block.")
|
||||
# add mtp_block
|
||||
pattern = r"(model\.mtp_layers\.\d+\.)"
|
||||
match = re.match(pattern, name)
|
||||
if match:
|
||||
name = name.replace(match.group(), match.group() + "mtp_block.")
|
||||
return name
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user