mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 01:42:14 +08:00
[Mistral-Small 3.1] Update docs and tests (#14977)
Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
400d483e87
commit
f863ffc965
@ -879,7 +879,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
|||||||
- * `PixtralForConditionalGeneration`
|
- * `PixtralForConditionalGeneration`
|
||||||
* Pixtral
|
* Pixtral
|
||||||
* T + I<sup>+</sup>
|
* T + I<sup>+</sup>
|
||||||
* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
|
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc.
|
||||||
*
|
*
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
|
|||||||
@ -6,14 +6,14 @@ import argparse
|
|||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
# This script is an offline demo for running Pixtral.
|
# This script is an offline demo for running Mistral-Small-3
|
||||||
#
|
#
|
||||||
# If you want to run a server/client setup, please follow this code:
|
# If you want to run a server/client setup, please follow this code:
|
||||||
#
|
#
|
||||||
# - Server:
|
# - Server:
|
||||||
#
|
#
|
||||||
# ```bash
|
# ```bash
|
||||||
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
|
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
|
||||||
# ```
|
# ```
|
||||||
#
|
#
|
||||||
# - Client:
|
# - Client:
|
||||||
@ -23,7 +23,7 @@ from vllm.sampling_params import SamplingParams
|
|||||||
# --header 'Content-Type: application/json' \
|
# --header 'Content-Type: application/json' \
|
||||||
# --header 'Authorization: Bearer token' \
|
# --header 'Authorization: Bearer token' \
|
||||||
# --data '{
|
# --data '{
|
||||||
# "model": "mistralai/Pixtral-12B-2409",
|
# "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
|
||||||
# "messages": [
|
# "messages": [
|
||||||
# {
|
# {
|
||||||
# "role": "user",
|
# "role": "user",
|
||||||
@ -44,7 +44,7 @@ from vllm.sampling_params import SamplingParams
|
|||||||
|
|
||||||
|
|
||||||
def run_simple_demo(args: argparse.Namespace):
|
def run_simple_demo(args: argparse.Namespace):
|
||||||
model_name = "mistralai/Pixtral-12B-2409"
|
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
sampling_params = SamplingParams(max_tokens=8192)
|
sampling_params = SamplingParams(max_tokens=8192)
|
||||||
|
|
||||||
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
|
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
|
||||||
@ -83,7 +83,7 @@ def run_simple_demo(args: argparse.Namespace):
|
|||||||
|
|
||||||
|
|
||||||
def run_advanced_demo(args: argparse.Namespace):
|
def run_advanced_demo(args: argparse.Namespace):
|
||||||
model_name = "mistralai/Pixtral-12B-2409"
|
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
max_img_per_msg = 5
|
max_img_per_msg = 5
|
||||||
max_tokens_per_img = 4096
|
max_tokens_per_img = 4096
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
Run `pytest tests/models/test_mistral.py`.
|
Run `pytest tests/models/test_mistral.py`.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
@ -16,8 +15,7 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
|||||||
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
|
from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
|
||||||
TextPrompt, TokensPrompt)
|
|
||||||
from vllm.multimodal import MultiModalDataBuiltins
|
from vllm.multimodal import MultiModalDataBuiltins
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
from vllm.sequence import Logprob, SampleLogprobs
|
from vllm.sequence import Logprob, SampleLogprobs
|
||||||
@ -28,7 +26,11 @@ from ...utils import check_logprobs_close
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from _typeshed import StrPath
|
from _typeshed import StrPath
|
||||||
|
|
||||||
MODELS = ["mistralai/Pixtral-12B-2409"]
|
PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
|
||||||
|
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
|
||||||
|
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
|
||||||
|
|
||||||
IMG_URLS = [
|
IMG_URLS = [
|
||||||
"https://picsum.photos/id/237/400/300",
|
"https://picsum.photos/id/237/400/300",
|
||||||
"https://picsum.photos/id/231/200/300",
|
"https://picsum.photos/id/231/200/300",
|
||||||
@ -125,8 +127,10 @@ MAX_MODEL_LEN = [8192, 65536]
|
|||||||
FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
|
FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
|
||||||
assert FIXTURES_PATH.exists()
|
assert FIXTURES_PATH.exists()
|
||||||
|
|
||||||
FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
|
FIXTURE_LOGPROBS_CHAT = {
|
||||||
FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"
|
PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
|
||||||
|
MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
|
||||||
|
}
|
||||||
|
|
||||||
OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]
|
OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]
|
||||||
|
|
||||||
@ -166,12 +170,12 @@ def test_chat(
|
|||||||
model: str,
|
model: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
|
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
|
||||||
|
FIXTURE_LOGPROBS_CHAT[model])
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model,
|
model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
tokenizer_mode="mistral",
|
tokenizer_mode="mistral",
|
||||||
enable_chunked_prefill=False,
|
|
||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
@ -183,70 +187,40 @@ def test_chat(
|
|||||||
outputs.extend(output)
|
outputs.extend(output)
|
||||||
|
|
||||||
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
|
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
|
||||||
|
# Remove last `None` prompt_logprobs to compare with fixture
|
||||||
|
for i in range(len(logprobs)):
|
||||||
|
assert logprobs[i][-1] is None
|
||||||
|
logprobs[i] = logprobs[i][:-1]
|
||||||
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
|
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
|
||||||
outputs_1_lst=logprobs,
|
outputs_1_lst=logprobs,
|
||||||
name_0="h100_ref",
|
name_0="h100_ref",
|
||||||
name_1="output")
|
name_1="output")
|
||||||
|
|
||||||
|
|
||||||
@large_gpu_test(min_gb=80)
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
|
||||||
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
|
|
||||||
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
|
|
||||||
args = EngineArgs(
|
|
||||||
model=model,
|
|
||||||
tokenizer_mode="mistral",
|
|
||||||
enable_chunked_prefill=False,
|
|
||||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
engine = LLMEngine.from_engine_args(args)
|
|
||||||
|
|
||||||
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
|
|
||||||
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
count = 0
|
|
||||||
while True:
|
|
||||||
out = engine.step()
|
|
||||||
count += 1
|
|
||||||
for request_output in out:
|
|
||||||
if request_output.finished:
|
|
||||||
outputs.append(request_output)
|
|
||||||
|
|
||||||
if count == 2:
|
|
||||||
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
|
|
||||||
SAMPLING_PARAMS)
|
|
||||||
if not engine.has_unfinished_requests():
|
|
||||||
break
|
|
||||||
|
|
||||||
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
|
|
||||||
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
|
|
||||||
outputs_1_lst=logprobs,
|
|
||||||
name_0="h100_ref",
|
|
||||||
name_1="output")
|
|
||||||
|
|
||||||
|
|
||||||
@large_gpu_test(min_gb=48)
|
@large_gpu_test(min_gb=48)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"prompt,expected_ranges",
|
"prompt,expected_ranges",
|
||||||
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
|
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
|
||||||
"offset": 10,
|
"offset": 11,
|
||||||
"length": 494
|
"length": 494
|
||||||
}]),
|
}]),
|
||||||
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
|
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
|
||||||
"offset": 10,
|
"offset": 11,
|
||||||
"length": 266
|
"length": 266
|
||||||
}, {
|
}, {
|
||||||
"offset": 276,
|
"offset": 277,
|
||||||
"length": 1056
|
"length": 1056
|
||||||
}, {
|
}, {
|
||||||
"offset": 1332,
|
"offset": 1333,
|
||||||
"length": 418
|
"length": 418
|
||||||
}])])
|
}])])
|
||||||
def test_multi_modal_placeholders(
|
def test_multi_modal_placeholders(vllm_runner, prompt,
|
||||||
vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
|
expected_ranges: list[PlaceholderRange],
|
||||||
|
monkeypatch) -> None:
|
||||||
|
|
||||||
|
# This placeholder checking test only works with V0 engine
|
||||||
|
# where `multi_modal_placeholders` is returned with `RequestOutput`
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
"mistral-community/pixtral-12b",
|
"mistral-community/pixtral-12b",
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
|
|||||||
1
tests/models/fixtures/mistral_small_3_chat.json
Normal file
1
tests/models/fixtures/mistral_small_3_chat.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user