mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 22:35:32 +08:00
[Bugfix] Fix Mistral tool-parser regex for nested JSON (#20093)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
296ce95d8e
commit
754b00edb3
@ -10,6 +10,7 @@ import pytest
|
|||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||||
MistralToolCall, MistralToolParser)
|
MistralToolCall, MistralToolParser)
|
||||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
|
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
|
|
||||||
@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
|
|||||||
schema=SAMPLE_JSON_SCHEMA)
|
schema=SAMPLE_JSON_SCHEMA)
|
||||||
except jsonschema.exceptions.ValidationError:
|
except jsonschema.exceptions.ValidationError:
|
||||||
pytest.fail("Generated response is not valid with JSON schema")
|
pytest.fail("Generated response is not valid with JSON schema")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_function_call_nested_json():
|
||||||
|
"""Ensure that the function-name regex captures the entire outer-most
|
||||||
|
JSON block, including nested braces."""
|
||||||
|
|
||||||
|
# Create a minimal stub tokenizer that provides the few attributes the
|
||||||
|
# parser accesses (`version` and `get_vocab`).
|
||||||
|
class _StubMistralTokenizer(MistralTokenizer):
|
||||||
|
version = 11 # Satisfy the version check
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_vocab():
|
||||||
|
# Provide the special TOOL_CALLS token expected by the parser.
|
||||||
|
return {"[TOOL_CALLS]": 0}
|
||||||
|
|
||||||
|
tokenizer = _StubMistralTokenizer()
|
||||||
|
parser = MistralToolParser(tokenizer)
|
||||||
|
|
||||||
|
# Craft a model output featuring nested JSON inside the arguments.
|
||||||
|
args_dict = {
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
"sub_dict": {
|
||||||
|
"foo": "bar",
|
||||||
|
"inner": {
|
||||||
|
"x": 1,
|
||||||
|
"y": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_output = (
|
||||||
|
f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}")
|
||||||
|
|
||||||
|
parsed = parser.extract_tool_calls(model_output, None)
|
||||||
|
|
||||||
|
# Assertions: the tool call is detected and the full nested JSON is parsed
|
||||||
|
# without truncation.
|
||||||
|
assert parsed.tools_called
|
||||||
|
|
||||||
|
assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
|
||||||
|
assert parsed.tool_calls[0].function.name == "get_current_weather"
|
||||||
|
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
|
||||||
|
# No additional content outside the tool call should be returned.
|
||||||
|
assert parsed.content is None
|
||||||
|
|||||||
@ -77,8 +77,8 @@ class MistralToolParser(ToolParser):
|
|||||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||||
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
|
self.fn_name_regex = re.compile(
|
||||||
re.DOTALL)
|
r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
|
||||||
else:
|
else:
|
||||||
self.fn_name_regex = None
|
self.fn_name_regex = None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user