[Bugfix] Fix mistral model tests (#17181)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-04-25 21:03:34 +08:00 committed by GitHub
parent 7feae92c1f
commit 19dcc02a72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 27 deletions

View File

@ -10,8 +10,8 @@ import jsonschema
import jsonschema.exceptions import jsonschema.exceptions
import pytest import pytest
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser) MistralToolCall, MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
) )
@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
assert "�" not in outputs[0].outputs[0].text.strip() assert "�" not in outputs[0].outputs[0].text.strip()
@pytest.mark.skip("RE-ENABLE: test is currently failing on main.") @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model",
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
parsed_message = tool_parser.extract_tool_calls(model_output, None) parsed_message = tool_parser.extract_tool_calls(model_output, None)
assert parsed_message.tools_called assert parsed_message.tools_called
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
assert parsed_message.tool_calls[ assert parsed_message.tool_calls[
0].function.name == "get_current_weather" 0].function.name == "get_current_weather"
assert parsed_message.tool_calls[ assert parsed_message.tool_calls[
@ -281,13 +279,23 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend", @pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"]) ["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(vllm_runner, model: str, def test_mistral_guided_decoding(
guided_backend: str) -> None: monkeypatch: pytest.MonkeyPatch,
with vllm_runner(model, dtype='bfloat16', vllm_runner,
tokenizer_mode="mistral") as vllm_model: model: str,
guided_backend: str,
) -> None:
with monkeypatch.context() as m:
# Guided JSON not supported in xgrammar + V1 yet
m.setenv("VLLM_USE_V1", "0")
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, with vllm_runner(
backend=guided_backend) model,
dtype='bfloat16',
tokenizer_mode="mistral",
guided_decoding_backend=guided_backend,
) as vllm_model:
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
params = SamplingParams(max_tokens=512, params = SamplingParams(max_tokens=512,
temperature=0.7, temperature=0.7,
guided_decoding=guided_decoding) guided_decoding=guided_decoding)

View File

@ -38,6 +38,10 @@ class MistralToolCall(ToolCall):
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9)) return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
@ToolParserManager.register_module("mistral") @ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser): class MistralToolParser(ToolParser):