[Bugfix] Fix Mistral tool-parser regex for nested JSON (#20093)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-06-26 10:01:17 +09:00 committed by GitHub
parent 296ce95d8e
commit 754b00edb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 2 deletions

View File

@ -10,6 +10,7 @@ import pytest
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall, MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer
from ...utils import check_logprobs_close
@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
def test_mistral_function_call_nested_json():
    """Check that a tool call whose arguments contain nested JSON objects is
    extracted in full.

    The model output is built as ``<bot_token><function name><JSON args>``
    and fed to ``MistralToolParser.extract_tool_calls``; the parsed arguments
    must round-trip to the original dictionary, nested braces included.
    """

    class _StubMistralTokenizer(MistralTokenizer):
        """Bare-bones tokenizer exposing only `version` and `get_vocab`."""

        # Value chosen to satisfy the parser's version check.
        version = 11

        def __init__(self):
            # Deliberately does not call the parent initializer.
            return None

        @staticmethod
        def get_vocab():
            # The parser looks up the special TOOL_CALLS token in the vocab.
            return {"[TOOL_CALLS]": 0}

    tool_parser = MistralToolParser(_StubMistralTokenizer())

    # Arguments with two levels of nesting below the top-level object.
    expected_args = {
        "city": "Dallas",
        "state": "TX",
        "unit": "fahrenheit",
        "sub_dict": {
            "foo": "bar",
            "inner": {
                "x": 1,
                "y": 2
            }
        },
    }
    serialized_args = json.dumps(expected_args)
    raw_output = f"{tool_parser.bot_token}get_current_weather{serialized_args}"

    result = tool_parser.extract_tool_calls(raw_output, None)

    # A tool call must be detected before its fields are inspected.
    assert result.tools_called

    first_call = result.tool_calls[0]
    assert MistralToolCall.is_valid_id(first_call.id)
    assert first_call.function.name == "get_current_weather"
    # The complete nested structure must survive, without truncation.
    assert json.loads(first_call.function.arguments) == expected_args

    # Nothing outside the tool call should be reported as content.
    assert result.content is None

View File

@ -77,8 +77,8 @@ class MistralToolParser(ToolParser):
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
re.DOTALL)
self.fn_name_regex = re.compile(
r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
else:
self.fn_name_regex = None