mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 00:04:27 +08:00
[Bugfix] Fix mistral model tests (#17181)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
7feae92c1f
commit
19dcc02a72
@ -10,8 +10,8 @@ import jsonschema
|
|||||||
import jsonschema.exceptions
|
import jsonschema.exceptions
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||||
MistralToolParser)
|
MistralToolCall, MistralToolParser)
|
||||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
|
|
||||||
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
|
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
|
|||||||
assert "<EFBFBD>" not in outputs[0].outputs[0].text.strip()
|
assert "<EFBFBD>" not in outputs[0].outputs[0].text.strip()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
|
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@pytest.mark.parametrize("model",
|
|
||||||
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
|
|
||||||
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
|||||||
parsed_message = tool_parser.extract_tool_calls(model_output, None)
|
parsed_message = tool_parser.extract_tool_calls(model_output, None)
|
||||||
|
|
||||||
assert parsed_message.tools_called
|
assert parsed_message.tools_called
|
||||||
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
|
|
||||||
|
assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
|
||||||
assert parsed_message.tool_calls[
|
assert parsed_message.tool_calls[
|
||||||
0].function.name == "get_current_weather"
|
0].function.name == "get_current_weather"
|
||||||
assert parsed_message.tool_calls[
|
assert parsed_message.tool_calls[
|
||||||
@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
|||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("guided_backend",
|
@pytest.mark.parametrize("guided_backend",
|
||||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||||
def test_mistral_guided_decoding(vllm_runner, model: str,
|
def test_mistral_guided_decoding(
|
||||||
guided_backend: str) -> None:
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
with vllm_runner(model, dtype='bfloat16',
|
vllm_runner,
|
||||||
tokenizer_mode="mistral") as vllm_model:
|
model: str,
|
||||||
|
guided_backend: str,
|
||||||
|
) -> None:
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
# Guided JSON not supported in xgrammar + V1 yet
|
||||||
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
|
|
||||||
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
|
with vllm_runner(
|
||||||
backend=guided_backend)
|
model,
|
||||||
params = SamplingParams(max_tokens=512,
|
dtype='bfloat16',
|
||||||
temperature=0.7,
|
tokenizer_mode="mistral",
|
||||||
guided_decoding=guided_decoding)
|
guided_decoding_backend=guided_backend,
|
||||||
|
) as vllm_model:
|
||||||
|
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
|
||||||
|
params = SamplingParams(max_tokens=512,
|
||||||
|
temperature=0.7,
|
||||||
|
guided_decoding=guided_decoding)
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "you are a helpful assistant"
|
"content": "you are a helpful assistant"
|
||||||
}, {
|
}, {
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content":
|
"content":
|
||||||
f"Give an example JSON for an employee profile that "
|
f"Give an example JSON for an employee profile that "
|
||||||
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
|
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
|
||||||
}]
|
}]
|
||||||
outputs = vllm_model.model.chat(messages, sampling_params=params)
|
outputs = vllm_model.model.chat(messages, sampling_params=params)
|
||||||
|
|
||||||
generated_text = outputs[0].outputs[0].text
|
generated_text = outputs[0].outputs[0].text
|
||||||
json_response = json.loads(generated_text)
|
json_response = json.loads(generated_text)
|
||||||
|
|||||||
@ -38,6 +38,10 @@ class MistralToolCall(ToolCall):
|
|||||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||||
return "".join(choices(ALPHANUMERIC, k=9))
|
return "".join(choices(ALPHANUMERIC, k=9))
|
||||||
|
|
||||||
|
@staticmethod
def is_valid_id(id: str) -> bool:
    """Return whether ``id`` has the shape of a generated tool-call ID.

    A valid ID is exactly nine characters long and consists only of
    ASCII letters and digits, mirroring the nine-character IDs built by
    the ID generator of this class (``choices(ALPHANUMERIC, k=9)``).
    NOTE(review): assumes ``ALPHANUMERIC`` holds ASCII letters and digits
    only -- confirm against its definition.

    Args:
        id: Candidate tool-call ID. The name shadows the builtin, but it
            is kept so existing keyword callers keep working.

    Returns:
        ``True`` if ``id`` is nine ASCII alphanumeric characters,
        ``False`` otherwise (including for the empty string).
    """
    # str.isalnum() is Unicode-aware and would accept characters such as
    # "é" or "٣"; require ASCII too so only [a-zA-Z0-9] passes.
    return len(id) == 9 and id.isascii() and id.isalnum()
|
||||||
|
|
||||||
|
|
||||||
@ToolParserManager.register_module("mistral")
|
@ToolParserManager.register_module("mistral")
|
||||||
class MistralToolParser(ToolParser):
|
class MistralToolParser(ToolParser):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user