mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 05:05:46 +08:00
[Model][MiniMax-M2] Support MiniMax-M2 Model (#27535)
Signed-off-by: xuebi <xuebi@minimaxi.com> Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
parent
55cba4a05c
commit
720af6ab79
@ -341,6 +341,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"MiniMaxM1ForCausalLM": _HfExamplesInfo(
|
||||
"MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True
|
||||
),
|
||||
"MiniMaxM2ForCausalLM": _HfExamplesInfo(
|
||||
"MiniMaxAI/MiniMax-M2", trust_remote_code=True
|
||||
),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MixtralForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
|
||||
@ -16,6 +16,7 @@ from .kimi_k2_tool_parser import KimiK2ToolParser
|
||||
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .longcat_tool_parser import LongcatFlashToolParser
|
||||
from .minimax_m2_tool_parser import MinimaxM2ToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .olmo3_tool_parser import Olmo3PythonicToolParser
|
||||
@ -56,4 +57,5 @@ __all__ = [
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
"OpenAIToolParser",
|
||||
"MinimaxM2ToolParser",
|
||||
]
|
||||
|
||||
644
vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
Normal file
644
vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py
Normal file
@ -0,0 +1,644 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("minimax_m2")
|
||||
class MinimaxM2ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
# Sentinel tokens
|
||||
self.tool_call_start_token: str = "<minimax:tool_call>"
|
||||
self.tool_call_end_token: str = "</minimax:tool_call>"
|
||||
self.invoke_start_prefix: str = "<invoke name="
|
||||
self.invoke_end_token: str = "</invoke>"
|
||||
self.parameter_prefix: str = "<parameter name="
|
||||
self.parameter_end_token: str = "</parameter>"
|
||||
|
||||
# Streaming state variables
|
||||
self.current_tool_name_sent: bool = False
|
||||
# Override base class type - we use string IDs for tool calls
|
||||
self.current_tool_id: str | None = None # type: ignore
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.is_tool_call_started: bool = False
|
||||
self.failed_count: int = 0
|
||||
|
||||
# Initialize streaming state variables
|
||||
self.current_tool_index: int = 0
|
||||
self.invoke_index: int = 0
|
||||
self.header_sent: bool = False
|
||||
self.current_function_name: str | None = None
|
||||
self.current_param_name: str | None = None
|
||||
self.current_param_value: str = ""
|
||||
self.param_count: int = 0
|
||||
self.in_param: bool = False
|
||||
self.in_function: bool = False
|
||||
self.accumulated_text: str = ""
|
||||
self.json_started: bool = False
|
||||
self.json_closed: bool = False
|
||||
self.accumulated_params: dict = {}
|
||||
self.streaming_request: ChatCompletionRequest | None = None
|
||||
|
||||
# Enhanced streaming state - reset for each new message
|
||||
self._reset_streaming_state()
|
||||
|
||||
# Regex patterns for complete parsing
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
|
||||
)
|
||||
self.invoke_complete_regex = re.compile(
|
||||
r"<invoke name=(.*?)</invoke>", re.DOTALL
|
||||
)
|
||||
self.parameter_complete_regex = re.compile(
|
||||
r"<parameter name=(.*?)</parameter>", re.DOTALL
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"MiniMax M2 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"vLLM Successfully import tool parser %s !", self.__class__.__name__
|
||||
)
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.invoke_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
# Store accumulated parameters for type conversion
|
||||
self.accumulated_params = {}
|
||||
self.streaming_request = None
|
||||
# Clear previous tool call history to avoid state pollution
|
||||
self.prev_tool_call_arr.clear()
|
||||
|
||||
def _extract_name(self, name_str: str) -> str:
|
||||
"""Extract name from quoted string."""
|
||||
name_str = name_str.strip()
|
||||
if (
|
||||
name_str.startswith('"')
|
||||
and name_str.endswith('"')
|
||||
or name_str.startswith("'")
|
||||
and name_str.endswith("'")
|
||||
):
|
||||
return name_str[1:-1]
|
||||
return name_str
|
||||
|
||||
def _convert_param_value(self, value: str, param_type: str) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
if value.lower() == "null":
|
||||
return None
|
||||
|
||||
param_type = param_type.lower()
|
||||
if param_type in ["string", "str", "text"]:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["number", "float"]:
|
||||
try:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif param_type in ["object", "array"]:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
else:
|
||||
# Try JSON parse first, fallback to string
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def _parse_single_invoke(
|
||||
self, invoke_str: str, tools: list | None
|
||||
) -> ToolCall | None:
|
||||
"""Parse a single <invoke> block."""
|
||||
# Extract function name
|
||||
name_match = re.search(r"^([^>]+)", invoke_str)
|
||||
if not name_match:
|
||||
return None
|
||||
|
||||
function_name = self._extract_name(name_match.group(1))
|
||||
|
||||
# Get parameter configuration
|
||||
param_config = {}
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if (
|
||||
hasattr(tool, "function")
|
||||
and tool.function.name == function_name
|
||||
and hasattr(tool.function, "parameters")
|
||||
):
|
||||
params = tool.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
param_config = params["properties"]
|
||||
break
|
||||
|
||||
# Extract parameters
|
||||
param_dict = {}
|
||||
for match in self.parameter_complete_regex.findall(invoke_str):
|
||||
param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
|
||||
if param_match:
|
||||
param_name = self._extract_name(param_match.group(1))
|
||||
param_value = param_match.group(2).strip()
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Get parameter type
|
||||
param_type = "string"
|
||||
if (
|
||||
param_name in param_config
|
||||
and isinstance(param_config[param_name], dict)
|
||||
and "type" in param_config[param_name]
|
||||
):
|
||||
param_type = param_config[param_name]["type"]
|
||||
|
||||
# Convert value
|
||||
param_dict[param_name] = self._convert_param_value(
|
||||
param_value, param_type
|
||||
)
|
||||
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=json.dumps(param_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""Extract tool calls from complete model output (non-streaming)."""
|
||||
# Quick check
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
tool_calls = []
|
||||
|
||||
# Find all complete tool_call blocks
|
||||
for tool_call_match in self.tool_call_complete_regex.findall(model_output):
|
||||
# Find all invokes within this tool_call
|
||||
for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
|
||||
tool_call = self._parse_single_invoke(
|
||||
invoke_match, request.tools if request else None
|
||||
)
|
||||
if tool_call:
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if not tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# Update prev_tool_call_arr
|
||||
self.prev_tool_call_arr.clear()
|
||||
for tool_call in tool_calls:
|
||||
self.prev_tool_call_arr.append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract content before first tool call
|
||||
first_tool_idx = model_output.find(self.tool_call_start_token)
|
||||
content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True, tool_calls=tool_calls, content=content
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error extracting tool calls")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
current_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming model output."""
|
||||
|
||||
# Store request for type conversion
|
||||
if not previous_text or self.tool_call_start_token in delta_text:
|
||||
self._reset_streaming_state()
|
||||
self.streaming_request = request
|
||||
|
||||
# If no delta text, return None unless it's an EOS token after tools
|
||||
if not delta_text:
|
||||
# Check if this is an EOS token after all tool calls are complete
|
||||
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
|
||||
# Count complete tool calls
|
||||
complete_calls = len(
|
||||
self.tool_call_complete_regex.findall(current_text)
|
||||
)
|
||||
|
||||
# If we have completed tool calls and populated prev_tool_call_arr
|
||||
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
|
||||
# Check if all tool calls are closed
|
||||
open_calls = current_text.count(
|
||||
self.tool_call_start_token
|
||||
) - current_text.count(self.tool_call_end_token)
|
||||
if open_calls == 0:
|
||||
# Return empty delta for finish_reason processing
|
||||
return DeltaMessage(content="")
|
||||
elif not self.is_tool_call_started and current_text:
|
||||
# This is a regular content response that's now complete
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
# Update accumulated text
|
||||
self.accumulated_text = current_text
|
||||
|
||||
# Check if we need to advance to next tool
|
||||
if self.json_closed and not self.in_function:
|
||||
# Check if this tool call has ended
|
||||
invoke_ends = current_text.count(self.invoke_end_token)
|
||||
if invoke_ends > self.current_tool_index:
|
||||
# This tool has ended, advance to next
|
||||
self.current_tool_index += 1
|
||||
self.header_sent = False
|
||||
self.param_count = 0
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
self.in_function = False # Now we can safely set this to False
|
||||
self.accumulated_params = {}
|
||||
# Continue processing next tool
|
||||
return None
|
||||
|
||||
# Handle normal content before tool calls
|
||||
if not self.is_tool_call_started:
|
||||
# Check if tool call is starting
|
||||
if (
|
||||
self.tool_call_start_token_id in delta_token_ids
|
||||
or self.tool_call_start_token in delta_text
|
||||
):
|
||||
self.is_tool_call_started = True
|
||||
# Return any content before the tool call
|
||||
if self.tool_call_start_token in delta_text:
|
||||
content_before = delta_text[
|
||||
: delta_text.index(self.tool_call_start_token)
|
||||
]
|
||||
if content_before:
|
||||
return DeltaMessage(content=content_before)
|
||||
return None
|
||||
else:
|
||||
# Check if we're between tool calls - skip whitespace
|
||||
if (
|
||||
current_text.rstrip().endswith(self.tool_call_end_token)
|
||||
and delta_text.strip() == ""
|
||||
):
|
||||
# We just ended a tool call, skip whitespace
|
||||
return None
|
||||
# Normal content, no tool call
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Check if we're between tool calls (waiting for next one)
|
||||
invoke_starts_count = current_text.count(self.invoke_start_prefix)
|
||||
if self.current_tool_index >= invoke_starts_count:
|
||||
# We're past all tool calls, shouldn't be here
|
||||
return None
|
||||
|
||||
# Find the current tool call portion
|
||||
invoke_start_positions: list[int] = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = current_text.find(self.invoke_start_prefix, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
invoke_start_positions.append(idx)
|
||||
idx += len(self.invoke_start_prefix)
|
||||
|
||||
if self.current_tool_index >= len(invoke_start_positions):
|
||||
# No more tool calls to process yet
|
||||
return None
|
||||
|
||||
invoke_start_idx = invoke_start_positions[self.current_tool_index]
|
||||
# Find where this tool call ends (or current position if not ended yet)
|
||||
invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
|
||||
if invoke_end_idx == -1:
|
||||
tool_text = current_text[invoke_start_idx:]
|
||||
else:
|
||||
tool_text = current_text[
|
||||
invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
|
||||
]
|
||||
|
||||
# Looking for function header
|
||||
if not self.header_sent:
|
||||
if self.invoke_start_prefix in tool_text:
|
||||
func_start = tool_text.find(self.invoke_start_prefix) + len(
|
||||
self.invoke_start_prefix
|
||||
)
|
||||
# Find the end quote for the function name
|
||||
func_end = tool_text.find(">", func_start)
|
||||
|
||||
if func_end != -1:
|
||||
# Found complete function name
|
||||
function_name_raw = tool_text[func_start:func_end]
|
||||
self.current_function_name = self._extract_name(function_name_raw)
|
||||
self.current_tool_id = self._generate_tool_call_id()
|
||||
self.header_sent = True
|
||||
self.in_function = True
|
||||
|
||||
# Add to prev_tool_call_arr immediately when we detect a tool call
|
||||
# Each tool call should be recorded regardless of function name
|
||||
# Ensure we don't add the same tool call index multiple times
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_index:
|
||||
self.prev_tool_call_arr.append(
|
||||
{
|
||||
"name": self.current_function_name,
|
||||
"arguments": "{}", # Placeholder, will be updated later
|
||||
}
|
||||
)
|
||||
|
||||
# Send header with function info
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
id=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=self.current_function_name, arguments=""
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
)
|
||||
return None
|
||||
|
||||
# We've sent header, now handle function body
|
||||
if self.in_function:
|
||||
# Send opening brace if not sent yet
|
||||
if self.in_function and not self.json_started:
|
||||
self.json_started = True
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="{"),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Make sure json_started is set if we're processing parameters
|
||||
if not self.json_started:
|
||||
self.json_started = True
|
||||
|
||||
# Check for function end in accumulated text
|
||||
if not self.json_closed and self.invoke_end_token in tool_text:
|
||||
# Count total parameters in the tool text
|
||||
total_param_count = tool_text.count(self.parameter_prefix)
|
||||
|
||||
# Only close JSON if all parameters have been processed
|
||||
if self.param_count >= total_param_count:
|
||||
# Close JSON
|
||||
self.json_closed = True
|
||||
|
||||
# Extract complete tool call
|
||||
# Find the invoke content
|
||||
invoke_start = tool_text.find(self.invoke_start_prefix) + len(
|
||||
self.invoke_start_prefix
|
||||
)
|
||||
invoke_content_end = tool_text.find(
|
||||
self.invoke_end_token, invoke_start
|
||||
)
|
||||
if invoke_content_end != -1:
|
||||
invoke_content = tool_text[invoke_start:invoke_content_end]
|
||||
# Parse to get the complete arguments
|
||||
try:
|
||||
parsed_tool = self._parse_single_invoke(
|
||||
invoke_content,
|
||||
self.streaming_request.tools
|
||||
if self.streaming_request
|
||||
else None,
|
||||
)
|
||||
if parsed_tool and self.current_tool_index < len(
|
||||
self.prev_tool_call_arr
|
||||
):
|
||||
# Update existing entry in prev_tool_call_arr
|
||||
args = parsed_tool.function.arguments
|
||||
self.prev_tool_call_arr[self.current_tool_index][
|
||||
"arguments"
|
||||
] = args
|
||||
except Exception:
|
||||
pass # Ignore parsing errors during streaming
|
||||
|
||||
result = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="}"),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Reset state for next tool
|
||||
self.json_closed = True
|
||||
self.in_function = False
|
||||
self.accumulated_params = {}
|
||||
|
||||
logger.debug("[M2_STREAMING] Tool call completed")
|
||||
|
||||
return result
|
||||
else:
|
||||
# Don't close JSON yet, continue processing parameters
|
||||
return None
|
||||
|
||||
# Look for parameters
|
||||
# Find all parameter starts
|
||||
param_starts = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = tool_text.find(self.parameter_prefix, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
param_starts.append(idx)
|
||||
idx += len(self.parameter_prefix)
|
||||
|
||||
# Check if we should start a new parameter
|
||||
if (
|
||||
not self.in_param
|
||||
and self.param_count < len(param_starts)
|
||||
and len(param_starts) > self.param_count
|
||||
):
|
||||
# Process the next parameter
|
||||
param_idx = param_starts[self.param_count]
|
||||
param_start = param_idx + len(self.parameter_prefix)
|
||||
remaining = tool_text[param_start:]
|
||||
|
||||
if ">" in remaining:
|
||||
# We have the complete parameter name
|
||||
name_end = remaining.find(">")
|
||||
param_name_raw = remaining[:name_end]
|
||||
self.current_param_name = self._extract_name(param_name_raw)
|
||||
|
||||
# Find the parameter value
|
||||
value_start = param_start + name_end + 1
|
||||
value_text = tool_text[value_start:]
|
||||
if value_text.startswith("\n"):
|
||||
value_text = value_text[1:]
|
||||
|
||||
# Find where this parameter ends
|
||||
param_end_idx = value_text.find(self.parameter_end_token)
|
||||
if param_end_idx == -1:
|
||||
# No closing tag, look for next parameter or function end
|
||||
next_param_idx = value_text.find(self.parameter_prefix)
|
||||
func_end_idx = value_text.find(self.invoke_end_token)
|
||||
|
||||
if next_param_idx != -1 and (
|
||||
func_end_idx == -1 or next_param_idx < func_end_idx
|
||||
):
|
||||
param_end_idx = next_param_idx
|
||||
elif func_end_idx != -1:
|
||||
param_end_idx = func_end_idx
|
||||
else:
|
||||
# Neither found, check if tool call is complete
|
||||
if self.invoke_end_token in tool_text:
|
||||
# Tool call and parameter is complete
|
||||
param_end_idx = len(value_text)
|
||||
else:
|
||||
# Still streaming, wait for more content
|
||||
return None
|
||||
|
||||
if param_end_idx != -1:
|
||||
# Complete parameter found
|
||||
param_value = value_text[:param_end_idx]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Store raw value for later processing
|
||||
self.accumulated_params[self.current_param_name] = param_value
|
||||
|
||||
# Get parameter configuration for type conversion
|
||||
param_config = {}
|
||||
if self.streaming_request and self.streaming_request.tools:
|
||||
for tool in self.streaming_request.tools:
|
||||
if (
|
||||
hasattr(tool, "function")
|
||||
and tool.function.name == self.current_function_name
|
||||
and hasattr(tool.function, "parameters")
|
||||
):
|
||||
params = tool.function.parameters
|
||||
if (
|
||||
isinstance(params, dict)
|
||||
and "properties" in params
|
||||
):
|
||||
param_config = params["properties"]
|
||||
break
|
||||
|
||||
# Get parameter type
|
||||
param_type = "string"
|
||||
if (
|
||||
self.current_param_name in param_config
|
||||
and isinstance(param_config[self.current_param_name], dict)
|
||||
and "type" in param_config[self.current_param_name]
|
||||
):
|
||||
param_type = param_config[self.current_param_name]["type"]
|
||||
|
||||
# Convert param value to appropriate type
|
||||
converted_value = self._convert_param_value(
|
||||
param_value, param_type
|
||||
)
|
||||
|
||||
# Build JSON fragment based on the converted type
|
||||
# Use json.dumps to properly serialize the value
|
||||
serialized_value = json.dumps(
|
||||
converted_value, ensure_ascii=False
|
||||
)
|
||||
|
||||
if self.param_count == 0:
|
||||
json_fragment = (
|
||||
f'"{self.current_param_name}": {serialized_value}'
|
||||
)
|
||||
else:
|
||||
json_fragment = (
|
||||
f', "{self.current_param_name}": {serialized_value}'
|
||||
)
|
||||
|
||||
self.param_count += 1
|
||||
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments=json_fragment),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
585
vllm/model_executor/models/minimax_m2.py
Normal file
585
vllm/model_executor/models/minimax_m2.py
Normal file
@ -0,0 +1,585 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniMaxM2 model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class MiniMaxM2MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
if self.tp_size > config.num_local_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_local_experts}."
|
||||
)
|
||||
self.use_routing_bias = getattr(config, "use_routing_bias", False)
|
||||
if self.use_routing_bias:
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.num_local_experts, dtype=torch.float32)
|
||||
)
|
||||
self.e_score_correction_bias.weight_loader = (
|
||||
MiniMaxM2MoE.ebias_weight_loader
|
||||
)
|
||||
else:
|
||||
self.e_score_correction_bias = None
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
scoring_func=config.scoring_func,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=1,
|
||||
topk_group=1,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight.to(torch.float32))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
final_hidden_states = final_hidden_states
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class MiniMaxM2Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rotary_dim: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
attn_window_size: int | None = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: int | None = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
per_layer_sliding_window=attn_window_size,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
self.q_norm = MiniMaxText01RMSNormTP(
|
||||
self.head_dim * self.total_num_heads, eps=rms_norm_eps
|
||||
)
|
||||
self.k_norm = MiniMaxText01RMSNormTP(
|
||||
self.head_dim * self.total_num_kv_heads, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MiniMaxM2DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
|
||||
max_position_embeddings = max(
|
||||
config.max_position_embeddings, config.max_model_len
|
||||
)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep=".")[-1])
|
||||
|
||||
# TODO: support MTP
|
||||
attn_window_size = getattr(config, "attn_window_size", None)
|
||||
if attn_window_size is not None:
|
||||
if isinstance(attn_window_size, list):
|
||||
attn_window_size = attn_window_size[layer_idx]
|
||||
elif isinstance(attn_window_size, int):
|
||||
attn_window_size = attn_window_size
|
||||
else:
|
||||
raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
|
||||
attn_window_size = None if attn_window_size <= 0 else attn_window_size
|
||||
|
||||
# different rope theta for full layer and swa layer
|
||||
swa_rope_theta = getattr(config, "swa_rope_theta", -1)
|
||||
# default to full rope theta
|
||||
swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
|
||||
rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = MiniMaxM2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rotary_dim=config.rotary_dim,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
attn_window_size=attn_window_size,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, "attention_bias", False),
|
||||
head_dim=getattr(config, "head_dim", None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
self.block_sparse_moe = MiniMaxM2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MiniMaxM2Model(nn.Module):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MiniMaxM2DecoderLayer(
|
||||
config,
|
||||
prefix,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers[self.start_layer : self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
ckpt_up_proj_name="w3",
|
||||
num_experts=self.config.num_local_experts,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class MiniMaxM2ForCausalLM(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
if hasattr(vllm_config.model_config, "max_model_len"):
|
||||
self.config.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.model = MiniMaxM2Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=None
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
||||
) -> IntermediateTensors:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
"residual": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
config: PretrainedConfig, weight_name: str
|
||||
) -> int | None:
|
||||
if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_mtp_modules):
|
||||
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
|
||||
return layer_idx + i
|
||||
return None
|
||||
@ -131,6 +131,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
||||
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
|
||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
# transformers's mpt class has lower case
|
||||
|
||||
@ -11,6 +11,7 @@ from .gptoss_reasoning_parser import GptOssReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .identity_reasoning_parser import IdentityReasoningParser
|
||||
from .minimax_m2_reasoning_parser import MiniMaxM2ReasoningParser
|
||||
from .mistral_reasoning_parser import MistralReasoningParser
|
||||
from .olmo3_reasoning_parser import Olmo3ReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
@ -34,4 +35,5 @@ __all__ = [
|
||||
"Step3ReasoningParser",
|
||||
"GptOssReasoningParser",
|
||||
"SeedOSSReasoningParser",
|
||||
"MiniMaxM2ReasoningParser",
|
||||
]
|
||||
|
||||
69
vllm/reasoning/minimax_m2_reasoning_parser.py
Normal file
69
vllm/reasoning/minimax_m2_reasoning_parser.py
Normal file
@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("minimax_m2")
|
||||
class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
|
||||
"""
|
||||
Reasoning parser for MiniMax M2 model.
|
||||
"""
|
||||
|
||||
@property
|
||||
def start_token(self) -> str:
|
||||
"""The token that starts reasoning content."""
|
||||
return "<think>"
|
||||
|
||||
@property
|
||||
def end_token(self) -> str:
|
||||
"""The token that ends reasoning content."""
|
||||
return "</think>"
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("minimax_m2_append_think")
|
||||
class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for MiniMax M2 model.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
self.end_token_id = self.vocab.get("</think>")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
end_token_id = self.end_token_id
|
||||
return any(input_id == end_token_id for input_id in reversed(input_ids))
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
if len(previous_token_ids) == 0:
|
||||
delta_text = "<think>" + delta_text
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> tuple[str | None, str | None]:
|
||||
return None, "<think>" + model_output
|
||||
Loading…
x
Reference in New Issue
Block a user