diff --git a/tests/models/registry.py b/tests/models/registry.py
index 8e11ee755bf7b..f227edd82761b 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -341,6 +341,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiniMaxM1ForCausalLM": _HfExamplesInfo(
"MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True
),
+ "MiniMaxM2ForCausalLM": _HfExamplesInfo(
+ "MiniMaxAI/MiniMax-M2", trust_remote_code=True
+ ),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index a72772f59cf2f..4541ca50822f7 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -16,6 +16,7 @@ from .kimi_k2_tool_parser import KimiK2ToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .longcat_tool_parser import LongcatFlashToolParser
+from .minimax_m2_tool_parser import MinimaxM2ToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .olmo3_tool_parser import Olmo3PythonicToolParser
@@ -56,4 +57,5 @@ __all__ = [
"SeedOssToolParser",
"Step3ToolParser",
"OpenAIToolParser",
+ "MinimaxM2ToolParser",
]
diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
new file mode 100644
index 0000000000000..06dd336bf9cf3
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
@@ -0,0 +1,644 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Any
+
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser,
+ ToolParserManager,
+)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("minimax_m2")
+class MinimaxM2ToolParser(ToolParser):
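+    """Tool parser for MiniMax M2's XML-style tool-call format.
+
+    Typically enabled via vLLM's OpenAI-compatible server (serve flags
+    assumed, shown for illustration):
+        vllm serve MiniMaxAI/MiniMax-M2 --enable-auto-tool-choice \
+            --tool-call-parser minimax_m2 --reasoning-parser minimax_m2
+    """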
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ self.prev_tool_call_arr: list[dict] = []
+
+        # Sentinel tokens
+        self.tool_call_start_token: str = "<minimax:tool_call>"
+        self.tool_call_end_token: str = "</minimax:tool_call>"
+        self.invoke_start_prefix: str = "<invoke name="
+        self.invoke_end_token: str = "</invoke>"
+        self.parameter_prefix: str = "<parameter name="
+        self.parameter_end_token: str = "</parameter>"
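+        # For reference, the model emits tool calls in an XML-like format
+        # (example values are illustrative):
+        #   <minimax:tool_call>
+        #   <invoke name="get_weather">
+        #   <parameter name="city">Paris</parameter>
+        #   </invoke>
+        #   </minimax:tool_call>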
+
+ # Streaming state variables
+ self.current_tool_name_sent: bool = False
+ # Override base class type - we use string IDs for tool calls
+ self.current_tool_id: str | None = None # type: ignore
+ self.streamed_args_for_tool: list[str] = []
+ self.is_tool_call_started: bool = False
+ self.failed_count: int = 0
+
+ # Initialize streaming state variables
+ self.current_tool_index: int = 0
+ self.invoke_index: int = 0
+ self.header_sent: bool = False
+ self.current_function_name: str | None = None
+ self.current_param_name: str | None = None
+ self.current_param_value: str = ""
+ self.param_count: int = 0
+ self.in_param: bool = False
+ self.in_function: bool = False
+ self.accumulated_text: str = ""
+ self.json_started: bool = False
+ self.json_closed: bool = False
+ self.accumulated_params: dict = {}
+ self.streaming_request: ChatCompletionRequest | None = None
+
+ # Enhanced streaming state - reset for each new message
+ self._reset_streaming_state()
+
+ # Regex patterns for complete parsing
+        self.tool_call_complete_regex = re.compile(
+            r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
+        )
+        self.invoke_complete_regex = re.compile(
+            r"<invoke name=(.*?)</invoke>", re.DOTALL
+        )
+        self.parameter_complete_regex = re.compile(
+            r"<parameter name=(.*?)</parameter>", re.DOTALL
+        )
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolParser "
+ "constructor during construction."
+ )
+
+ self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+ if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
+ raise RuntimeError(
+ "MiniMax M2 Tool parser could not locate tool call start/end "
+ "tokens in the tokenizer!"
+ )
+
+ logger.info(
+ "vLLM Successfully import tool parser %s !", self.__class__.__name__
+ )
+
+ def _generate_tool_call_id(self) -> str:
+ """Generate a unique tool call ID."""
+ return f"call_{uuid.uuid4().hex[:24]}"
+
+ def _reset_streaming_state(self):
+ """Reset all streaming state."""
+ self.current_tool_index = 0
+ self.invoke_index = 0
+ self.is_tool_call_started = False
+ self.header_sent = False
+ self.current_tool_id = None
+ self.current_function_name = None
+ self.current_param_name = None
+ self.current_param_value = ""
+ self.param_count = 0
+ self.in_param = False
+ self.in_function = False
+ self.accumulated_text = ""
+ self.json_started = False
+ self.json_closed = False
+ # Store accumulated parameters for type conversion
+ self.accumulated_params = {}
+ self.streaming_request = None
+ # Clear previous tool call history to avoid state pollution
+ self.prev_tool_call_arr.clear()
+
+ def _extract_name(self, name_str: str) -> str:
+ """Extract name from quoted string."""
+ name_str = name_str.strip()
+ if (
+ name_str.startswith('"')
+ and name_str.endswith('"')
+ or name_str.startswith("'")
+ and name_str.endswith("'")
+ ):
+ return name_str[1:-1]
+ return name_str
+
+ def _convert_param_value(self, value: str, param_type: str) -> Any:
+ """Convert parameter value to the correct type."""
+ if value.lower() == "null":
+ return None
+
+ param_type = param_type.lower()
+ if param_type in ["string", "str", "text"]:
+ return value
+ elif param_type in ["integer", "int"]:
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ return value
+ elif param_type in ["number", "float"]:
+ try:
+ val = float(value)
+ return val if val != int(val) else int(val)
+ except (ValueError, TypeError):
+ return value
+ elif param_type in ["boolean", "bool"]:
+ return value.lower() in ["true", "1"]
+ elif param_type in ["object", "array"]:
+ try:
+ return json.loads(value)
+ except json.JSONDecodeError:
+ return value
+ else:
+ # Try JSON parse first, fallback to string
+ try:
+ return json.loads(value)
+ except json.JSONDecodeError:
+ return value
+
+ def _parse_single_invoke(
+ self, invoke_str: str, tools: list | None
+ ) -> ToolCall | None:
+ """Parse a single block."""
+ # Extract function name
+        name_match = re.search(r"^([^>]+)>", invoke_str)
+ if not name_match:
+ return None
+
+ function_name = self._extract_name(name_match.group(1))
+
+ # Get parameter configuration
+ param_config = {}
+ if tools:
+ for tool in tools:
+ if (
+ hasattr(tool, "function")
+ and tool.function.name == function_name
+ and hasattr(tool.function, "parameters")
+ ):
+ params = tool.function.parameters
+ if isinstance(params, dict) and "properties" in params:
+ param_config = params["properties"]
+ break
+
+ # Extract parameters
+ param_dict = {}
+ for match in self.parameter_complete_regex.findall(invoke_str):
+ param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
+ if param_match:
+ param_name = self._extract_name(param_match.group(1))
+ param_value = param_match.group(2).strip()
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ # Get parameter type
+ param_type = "string"
+ if (
+ param_name in param_config
+ and isinstance(param_config[param_name], dict)
+ and "type" in param_config[param_name]
+ ):
+ param_type = param_config[param_name]["type"]
+
+ # Convert value
+ param_dict[param_name] = self._convert_param_value(
+ param_value, param_type
+ )
+
+ return ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_name,
+ arguments=json.dumps(param_dict, ensure_ascii=False),
+ ),
+ )
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ """Extract tool calls from complete model output (non-streaming)."""
+ # Quick check
+ if self.tool_call_start_token not in model_output:
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ try:
+ tool_calls = []
+
+ # Find all complete tool_call blocks
+ for tool_call_match in self.tool_call_complete_regex.findall(model_output):
+ # Find all invokes within this tool_call
+ for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
+ tool_call = self._parse_single_invoke(
+ invoke_match, request.tools if request else None
+ )
+ if tool_call:
+ tool_calls.append(tool_call)
+
+ if not tool_calls:
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ # Update prev_tool_call_arr
+ self.prev_tool_call_arr.clear()
+ for tool_call in tool_calls:
+ self.prev_tool_call_arr.append(
+ {
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ }
+ )
+
+ # Extract content before first tool call
+ first_tool_idx = model_output.find(self.tool_call_start_token)
+ content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
+
+ return ExtractedToolCallInformation(
+ tools_called=True, tool_calls=tool_calls, content=content
+ )
+
+ except Exception:
+ logger.exception("Error extracting tool calls")
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int], # pylint: disable=unused-argument
+ current_token_ids: Sequence[int], # pylint: disable=unused-argument
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> DeltaMessage | None:
+ """Extract tool calls from streaming model output."""
+
+ # Store request for type conversion
+ if not previous_text or self.tool_call_start_token in delta_text:
+ self._reset_streaming_state()
+ self.streaming_request = request
+
+ # If no delta text, return None unless it's an EOS token after tools
+ if not delta_text:
+ # Check if this is an EOS token after all tool calls are complete
+ if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
+ # Count complete tool calls
+ complete_calls = len(
+ self.tool_call_complete_regex.findall(current_text)
+ )
+
+ # If we have completed tool calls and populated prev_tool_call_arr
+ if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
+ # Check if all tool calls are closed
+ open_calls = current_text.count(
+ self.tool_call_start_token
+ ) - current_text.count(self.tool_call_end_token)
+ if open_calls == 0:
+ # Return empty delta for finish_reason processing
+ return DeltaMessage(content="")
+ elif not self.is_tool_call_started and current_text:
+ # This is a regular content response that's now complete
+ return DeltaMessage(content="")
+ return None
+
+ # Update accumulated text
+ self.accumulated_text = current_text
+
+ # Check if we need to advance to next tool
+ if self.json_closed and not self.in_function:
+ # Check if this tool call has ended
+ invoke_ends = current_text.count(self.invoke_end_token)
+ if invoke_ends > self.current_tool_index:
+ # This tool has ended, advance to next
+ self.current_tool_index += 1
+ self.header_sent = False
+ self.param_count = 0
+ self.json_started = False
+ self.json_closed = False
+ self.in_function = False # Now we can safely set this to False
+ self.accumulated_params = {}
+ # Continue processing next tool
+ return None
+
+ # Handle normal content before tool calls
+ if not self.is_tool_call_started:
+ # Check if tool call is starting
+ if (
+ self.tool_call_start_token_id in delta_token_ids
+ or self.tool_call_start_token in delta_text
+ ):
+ self.is_tool_call_started = True
+ # Return any content before the tool call
+ if self.tool_call_start_token in delta_text:
+ content_before = delta_text[
+ : delta_text.index(self.tool_call_start_token)
+ ]
+ if content_before:
+ return DeltaMessage(content=content_before)
+ return None
+ else:
+ # Check if we're between tool calls - skip whitespace
+ if (
+ current_text.rstrip().endswith(self.tool_call_end_token)
+ and delta_text.strip() == ""
+ ):
+ # We just ended a tool call, skip whitespace
+ return None
+ # Normal content, no tool call
+ return DeltaMessage(content=delta_text)
+
+ # Check if we're between tool calls (waiting for next one)
+ invoke_starts_count = current_text.count(self.invoke_start_prefix)
+ if self.current_tool_index >= invoke_starts_count:
+ # We're past all tool calls, shouldn't be here
+ return None
+
+ # Find the current tool call portion
+ invoke_start_positions: list[int] = []
+ idx = 0
+ while True:
+ idx = current_text.find(self.invoke_start_prefix, idx)
+ if idx == -1:
+ break
+ invoke_start_positions.append(idx)
+ idx += len(self.invoke_start_prefix)
+
+ if self.current_tool_index >= len(invoke_start_positions):
+ # No more tool calls to process yet
+ return None
+
+ invoke_start_idx = invoke_start_positions[self.current_tool_index]
+ # Find where this tool call ends (or current position if not ended yet)
+ invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
+ if invoke_end_idx == -1:
+ tool_text = current_text[invoke_start_idx:]
+ else:
+ tool_text = current_text[
+ invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
+ ]
+
+ # Looking for function header
+ if not self.header_sent:
+ if self.invoke_start_prefix in tool_text:
+ func_start = tool_text.find(self.invoke_start_prefix) + len(
+ self.invoke_start_prefix
+ )
+                # Find the closing ">" after the function name
+ func_end = tool_text.find(">", func_start)
+
+ if func_end != -1:
+ # Found complete function name
+ function_name_raw = tool_text[func_start:func_end]
+ self.current_function_name = self._extract_name(function_name_raw)
+ self.current_tool_id = self._generate_tool_call_id()
+ self.header_sent = True
+ self.in_function = True
+
+ # Add to prev_tool_call_arr immediately when we detect a tool call
+ # Each tool call should be recorded regardless of function name
+ # Ensure we don't add the same tool call index multiple times
+ if len(self.prev_tool_call_arr) <= self.current_tool_index:
+ self.prev_tool_call_arr.append(
+ {
+ "name": self.current_function_name,
+ "arguments": "{}", # Placeholder, will be updated later
+ }
+ )
+
+ # Send header with function info
+ return DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ id=self.current_tool_id,
+ function=DeltaFunctionCall(
+ name=self.current_function_name, arguments=""
+ ),
+ type="function",
+ )
+ ]
+ )
+ return None
+
+ # We've sent header, now handle function body
+ if self.in_function:
+ # Send opening brace if not sent yet
+            if not self.json_started:
+ self.json_started = True
+ return DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments="{"),
+ )
+ ]
+ )
+
+ # Make sure json_started is set if we're processing parameters
+ if not self.json_started:
+ self.json_started = True
+
+ # Check for function end in accumulated text
+ if not self.json_closed and self.invoke_end_token in tool_text:
+ # Count total parameters in the tool text
+ total_param_count = tool_text.count(self.parameter_prefix)
+
+ # Only close JSON if all parameters have been processed
+ if self.param_count >= total_param_count:
+ # Close JSON
+ self.json_closed = True
+
+ # Extract complete tool call
+ # Find the invoke content
+ invoke_start = tool_text.find(self.invoke_start_prefix) + len(
+ self.invoke_start_prefix
+ )
+ invoke_content_end = tool_text.find(
+ self.invoke_end_token, invoke_start
+ )
+ if invoke_content_end != -1:
+ invoke_content = tool_text[invoke_start:invoke_content_end]
+ # Parse to get the complete arguments
+ try:
+ parsed_tool = self._parse_single_invoke(
+ invoke_content,
+ self.streaming_request.tools
+ if self.streaming_request
+ else None,
+ )
+ if parsed_tool and self.current_tool_index < len(
+ self.prev_tool_call_arr
+ ):
+ # Update existing entry in prev_tool_call_arr
+ args = parsed_tool.function.arguments
+ self.prev_tool_call_arr[self.current_tool_index][
+ "arguments"
+ ] = args
+ except Exception:
+ pass # Ignore parsing errors during streaming
+
+ result = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments="}"),
+ )
+ ]
+ )
+
+ # Reset state for next tool
+ self.json_closed = True
+ self.in_function = False
+ self.accumulated_params = {}
+
+ logger.debug("[M2_STREAMING] Tool call completed")
+
+ return result
+ else:
+ # Don't close JSON yet, continue processing parameters
+ return None
+
+ # Look for parameters
+ # Find all parameter starts
+ param_starts = []
+ idx = 0
+ while True:
+ idx = tool_text.find(self.parameter_prefix, idx)
+ if idx == -1:
+ break
+ param_starts.append(idx)
+ idx += len(self.parameter_prefix)
+
+ # Check if we should start a new parameter
+            if not self.in_param and self.param_count < len(param_starts):
+ # Process the next parameter
+ param_idx = param_starts[self.param_count]
+ param_start = param_idx + len(self.parameter_prefix)
+ remaining = tool_text[param_start:]
+
+ if ">" in remaining:
+ # We have the complete parameter name
+ name_end = remaining.find(">")
+ param_name_raw = remaining[:name_end]
+ self.current_param_name = self._extract_name(param_name_raw)
+
+ # Find the parameter value
+ value_start = param_start + name_end + 1
+ value_text = tool_text[value_start:]
+ if value_text.startswith("\n"):
+ value_text = value_text[1:]
+
+ # Find where this parameter ends
+ param_end_idx = value_text.find(self.parameter_end_token)
+ if param_end_idx == -1:
+ # No closing tag, look for next parameter or function end
+ next_param_idx = value_text.find(self.parameter_prefix)
+ func_end_idx = value_text.find(self.invoke_end_token)
+
+ if next_param_idx != -1 and (
+ func_end_idx == -1 or next_param_idx < func_end_idx
+ ):
+ param_end_idx = next_param_idx
+ elif func_end_idx != -1:
+ param_end_idx = func_end_idx
+ else:
+ # Neither found, check if tool call is complete
+ if self.invoke_end_token in tool_text:
+ # Tool call and parameter is complete
+ param_end_idx = len(value_text)
+ else:
+ # Still streaming, wait for more content
+ return None
+
+ if param_end_idx != -1:
+ # Complete parameter found
+ param_value = value_text[:param_end_idx]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ # Store raw value for later processing
+ self.accumulated_params[self.current_param_name] = param_value
+
+ # Get parameter configuration for type conversion
+ param_config = {}
+ if self.streaming_request and self.streaming_request.tools:
+ for tool in self.streaming_request.tools:
+ if (
+ hasattr(tool, "function")
+ and tool.function.name == self.current_function_name
+ and hasattr(tool.function, "parameters")
+ ):
+ params = tool.function.parameters
+ if (
+ isinstance(params, dict)
+ and "properties" in params
+ ):
+ param_config = params["properties"]
+ break
+
+ # Get parameter type
+ param_type = "string"
+ if (
+ self.current_param_name in param_config
+ and isinstance(param_config[self.current_param_name], dict)
+ and "type" in param_config[self.current_param_name]
+ ):
+ param_type = param_config[self.current_param_name]["type"]
+
+ # Convert param value to appropriate type
+ converted_value = self._convert_param_value(
+ param_value, param_type
+ )
+
+ # Build JSON fragment based on the converted type
+ # Use json.dumps to properly serialize the value
+ serialized_value = json.dumps(
+ converted_value, ensure_ascii=False
+ )
+
+ if self.param_count == 0:
+ json_fragment = (
+ f'"{self.current_param_name}": {serialized_value}'
+ )
+ else:
+ json_fragment = (
+ f', "{self.current_param_name}": {serialized_value}'
+ )
+
+ self.param_count += 1
+
+ return DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments=json_fragment),
+ )
+ ]
+ )
+
+ return None
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
new file mode 100644
index 0000000000000..d122adfafbac0
--- /dev/null
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -0,0 +1,585 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only MiniMaxM2 model."""
+
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import (
+ get_pp_group,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+
+
+class MiniMaxM2MoE(nn.Module):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+
+ if self.tp_size > config.num_local_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {config.num_local_experts}."
+ )
+ self.use_routing_bias = getattr(config, "use_routing_bias", False)
+ if self.use_routing_bias:
+ self.e_score_correction_bias = nn.Parameter(
+ torch.empty(config.num_local_experts, dtype=torch.float32)
+ )
+ self.e_score_correction_bias.weight_loader = (
+ MiniMaxM2MoE.ebias_weight_loader
+ )
+ else:
+ self.e_score_correction_bias = None
+
+ self.experts = FusedMoE(
+ num_experts=config.num_local_experts,
+ top_k=config.num_experts_per_tok,
+ scoring_func=config.scoring_func,
+ use_grouped_topk=True,
+ num_expert_group=1,
+ topk_group=1,
+ e_score_correction_bias=self.e_score_correction_bias,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ reduce_results=False,
+ renormalize=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts",
+ )
+
+ self.gate = ReplicatedLinear(
+ config.hidden_size,
+ config.num_local_experts,
+ bias=False,
+ params_dtype=torch.float32,
+ quant_config=None,
+ prefix=f"{prefix}.gate",
+ )
+
+ @staticmethod
+ def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
+ assert param.size() == loaded_weight.size()
+ param.data.copy_(loaded_weight.to(torch.float32))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states.to(torch.float32))
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class MiniMaxM2Attention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ rotary_dim: int,
+ rope_theta: float = 10000,
+ rope_scaling: dict[str, Any] | None = None,
+ attn_window_size: int | None = None,
+ max_position_embeddings: int = 8192,
+ head_dim: int | None = None,
+ rms_norm_eps: float = 1e-06,
+ qkv_bias: bool = False,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=rotary_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ per_layer_sliding_window=attn_window_size,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ self.q_norm = MiniMaxText01RMSNormTP(
+ self.head_dim * self.total_num_heads, eps=rms_norm_eps
+ )
+ self.k_norm = MiniMaxText01RMSNormTP(
+ self.head_dim * self.total_num_kv_heads, eps=rms_norm_eps
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class MiniMaxM2DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ prefix: str,
+ model_config: ModelConfig,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+ if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
+ max_position_embeddings = max(
+ config.max_position_embeddings, config.max_model_len
+ )
+ # DecoderLayers are created with `make_layers` which passes the prefix
+ # with the layer's index.
+ layer_idx = int(prefix.split(sep=".")[-1])
+
+ # TODO: support MTP
+ attn_window_size = getattr(config, "attn_window_size", None)
+ if attn_window_size is not None:
+ if isinstance(attn_window_size, list):
+ attn_window_size = attn_window_size[layer_idx]
+            elif isinstance(attn_window_size, int):
+                pass  # already a single per-layer window size
+ else:
+ raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
+ attn_window_size = None if attn_window_size <= 0 else attn_window_size
+
+ # different rope theta for full layer and swa layer
+ swa_rope_theta = getattr(config, "swa_rope_theta", -1)
+ # default to full rope theta
+ swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
+ rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
+
+ self.layer_idx = layer_idx
+ self.self_attn = MiniMaxM2Attention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rotary_dim=config.rotary_dim,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ attn_window_size=attn_window_size,
+ max_position_embeddings=max_position_embeddings,
+ rms_norm_eps=config.rms_norm_eps,
+ qkv_bias=getattr(config, "attention_bias", False),
+ head_dim=getattr(config, "head_dim", None),
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+
+ self.block_sparse_moe = MiniMaxM2MoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ hidden_states = self.block_sparse_moe(hidden_states)
+
+ return hidden_states, residual
+
+
+@support_torch_compile
+class MiniMaxM2Model(nn.Module):
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ model_config = vllm_config.model_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=None,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: MiniMaxM2DecoderLayer(
+ config,
+ prefix,
+ model_config=model_config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ ),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for layer in self.layers[self.start_layer : self.end_layer]:
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
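+        # Maps per-expert checkpoint weights (w1 = gate, w2 = down, w3 = up)
+        # onto the fused expert parameters created by FusedMoE.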
+ return FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="w1",
+ ckpt_down_proj_name="w2",
+ ckpt_up_proj_name="w3",
+ num_experts=self.config.num_local_experts,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = self.get_expert_mapping()
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class MiniMaxM2ForCausalLM(nn.Module, SupportsPP):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ if hasattr(vllm_config.model_config, "max_model_len"):
+ self.config.max_model_len = vllm_config.model_config.max_model_len
+ self.model = MiniMaxM2Model(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=None
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor | IntermediateTensors:
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype, device: torch.device
+ ) -> IntermediateTensors:
+ return IntermediateTensors(
+ {
+ "hidden_states": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ "residual": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ }
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ return self.model.get_expert_mapping()
+
+
+def get_spec_layer_idx_from_weight_name(
+ config: PretrainedConfig, weight_name: str
+) -> int | None:
+ if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0):
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_mtp_modules):
+ if weight_name.startswith(f"model.layers.{layer_idx + i}."):
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 81d4a6bc5f3a7..e8212ef6d72d8 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -131,6 +131,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
+ "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index ecee1af439028..3d666882efb59 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -11,6 +11,7 @@ from .gptoss_reasoning_parser import GptOssReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .identity_reasoning_parser import IdentityReasoningParser
+from .minimax_m2_reasoning_parser import MiniMaxM2ReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .olmo3_reasoning_parser import Olmo3ReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
@@ -34,4 +35,5 @@ __all__ = [
"Step3ReasoningParser",
"GptOssReasoningParser",
"SeedOSSReasoningParser",
+ "MiniMaxM2ReasoningParser",
]
diff --git a/vllm/reasoning/minimax_m2_reasoning_parser.py b/vllm/reasoning/minimax_m2_reasoning_parser.py
new file mode 100644
index 0000000000000..0d4f6cc270a1c
--- /dev/null
+++ b/vllm/reasoning/minimax_m2_reasoning_parser.py
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ DeltaMessage,
+ ResponsesRequest,
+)
+from vllm.logger import init_logger
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("minimax_m2")
+class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
+ """
+ Reasoning parser for MiniMax M2 model.
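+
+    Splits output of the form "<think>reasoning</think>answer" into
+    reasoning_content and content via BaseThinkingReasoningParser.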
+ """
+
+ @property
+ def start_token(self) -> str:
+ """The token that starts reasoning content."""
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ """The token that ends reasoning content."""
+ return ""
+
+
+@ReasoningParserManager.register_module("minimax_m2_append_think")
+class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for MiniMax M2 model.
+ """
+
+ def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
+ super().__init__(tokenizer, *args, **kwargs)
+        self.end_token_id = self.vocab.get("</think>")
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ end_token_id = self.end_token_id
+ return any(input_id == end_token_id for input_id in reversed(input_ids))
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ return input_ids
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> DeltaMessage | None:
+ if len(previous_token_ids) == 0:
+ delta_text = "" + delta_text
+ return DeltaMessage(content=delta_text)
+
+ def extract_reasoning_content(
+ self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+ ) -> tuple[str | None, str | None]:
+ return None, "" + model_output