mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:15:31 +08:00
[Bugfix] Fix Mistral tool-parser regex for nested JSON (#20093)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
296ce95d8e
commit
754b00edb3
@ -10,6 +10,7 @@ import pytest
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall, MistralToolParser)
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
|
||||
schema=SAMPLE_JSON_SCHEMA)
|
||||
except jsonschema.exceptions.ValidationError:
|
||||
pytest.fail("Generated response is not valid with JSON schema")
|
||||
|
||||
|
||||
def test_mistral_function_call_nested_json():
|
||||
"""Ensure that the function-name regex captures the entire outer-most
|
||||
JSON block, including nested braces."""
|
||||
|
||||
# Create a minimal stub tokenizer that provides the few attributes the
|
||||
# parser accesses (`version` and `get_vocab`).
|
||||
class _StubMistralTokenizer(MistralTokenizer):
|
||||
version = 11 # Satisfy the version check
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_vocab():
|
||||
# Provide the special TOOL_CALLS token expected by the parser.
|
||||
return {"[TOOL_CALLS]": 0}
|
||||
|
||||
tokenizer = _StubMistralTokenizer()
|
||||
parser = MistralToolParser(tokenizer)
|
||||
|
||||
# Craft a model output featuring nested JSON inside the arguments.
|
||||
args_dict = {
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
"sub_dict": {
|
||||
"foo": "bar",
|
||||
"inner": {
|
||||
"x": 1,
|
||||
"y": 2
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
model_output = (
|
||||
f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}")
|
||||
|
||||
parsed = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
# Assertions: the tool call is detected and the full nested JSON is parsed
|
||||
# without truncation.
|
||||
assert parsed.tools_called
|
||||
|
||||
assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
|
||||
assert parsed.tool_calls[0].function.name == "get_current_weather"
|
||||
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
|
||||
# No additional content outside the tool call should be returned.
|
||||
assert parsed.content is None
|
||||
|
||||
@ -77,8 +77,8 @@ class MistralToolParser(ToolParser):
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
|
||||
re.DOTALL)
|
||||
self.fn_name_regex = re.compile(
|
||||
r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
|
||||
else:
|
||||
self.fn_name_regex = None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user