[Bugfix][Frontend] Prevent IndexError in MiniMax M2 tool parser during streaming extraction (#30555)

Signed-off-by: WangErXiao <863579016@qq.com>
This commit is contained in:
Robin 2025-12-17 16:37:57 +08:00 committed by GitHub
parent 4f735babb7
commit 20fda43151
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 4 deletions

View File

@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.tool_parsers.minimax_m2_tool_parser import (
MinimaxM2ToolParser,
)
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
"""Minimal fake tokenizer that exposes the attributes used by the
parser: a truthy model_tokenizer marker and a vocab mapping for the
special tokens.
"""
def __init__(self):
self.model_tokenizer = True
# The parser will look up start/end tokens by their literal strings
self.vocab = {
"<minimax:tool_call>": 1,
"</minimax:tool_call>": 2,
}
def get_vocab(self):
return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
return MinimaxM2ToolParser(FakeTokenizer())
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="get_weather">',
'<parameter name="city">',
"Seattle</parameter>",
"</invoke></minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 1
entry = parser.prev_tool_call_arr[0]
assert entry["name"] == "get_weather"
args = entry["arguments"]
assert args["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["OpenAI", "latest", "release"]</parameter>',
"</invoke>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["Gemini", "latest", "release"]</parameter>',
"</invoke>",
"</minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 2
for entry, expect_model in zip(parser.prev_tool_call_arr, ["OpenAI", "Gemini"]):
assert entry["name"] == "search_web"
args = json.dumps(entry["arguments"])
assert "technology" in args and "events" in args
assert expect_model in args
# check streamed_args_for_tool for serving_chat.py
for index in range(2):
expected_call = parser.prev_tool_call_arr[index].get("arguments", {})
expected_call = json.dumps(expected_call)
actual_call = parser.streamed_args_for_tool[index]
assert expected_call == actual_call

View File

@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser):
self.streaming_request = None self.streaming_request = None
# Clear previous tool call history to avoid state pollution # Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear() self.prev_tool_call_arr.clear()
# Reset streamed args tracking
self.streamed_args_for_tool.clear()
def _extract_name(self, name_str: str) -> str: def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string.""" """Extract name from quoted string."""
@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser):
self.prev_tool_call_arr.append( self.prev_tool_call_arr.append(
{ {
"name": self.current_function_name, "name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later "arguments": {}, # Placeholder, will be updated later
} }
) )
# Initialize streamed_args_for_tool for this tool call
if len(self.streamed_args_for_tool) <= self.current_tool_index:
self.streamed_args_for_tool.append("")
# Send header with function info # Send header with function info
return DeltaMessage( return DeltaMessage(
@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser):
# Send opening brace if not sent yet # Send opening brace if not sent yet
if self.in_function and not self.json_started: if self.in_function and not self.json_started:
self.json_started = True self.json_started = True
# Update streamed_args_for_tool for opening brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser):
args = parsed_tool.function.arguments args = parsed_tool.function.arguments
self.prev_tool_call_arr[self.current_tool_index][ self.prev_tool_call_arr[self.current_tool_index][
"arguments" "arguments"
] = args ] = json.loads(args)
except Exception: except Exception:
pass # Ignore parsing errors during streaming pass # Ignore parsing errors during streaming
@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser):
) )
] ]
) )
# Update streamed_args_for_tool for closing brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "}"
# Reset state for next tool # Reset state for next tool
self.json_closed = True self.json_closed = True
self.in_function = False self.in_function = False
@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser):
) )
self.param_count += 1 self.param_count += 1
# Update streamed_args_for_tool for this tool call
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += (
json_fragment
)
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(