[Voxtral] Add more tests (#21010)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit cfbcb9ed87
parent 76ddeff293
@@ -804,7 +804,7 @@ class VllmRunner:
 
     def get_inputs(
         self,
-        prompts: Union[list[str], list[torch.Tensor]],
+        prompts: Union[list[str], list[torch.Tensor], list[int]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -826,11 +826,16 @@ class VllmRunner:
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio
 
-            text_prompt_kwargs = {
-                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
-                prompt,
+            text_prompt_kwargs: dict[str, Any] = {
                 "multi_modal_data": multi_modal_data or None
             }
+            if isinstance(prompt, str):
+                text_prompt_kwargs["prompt"] = prompt
+            elif isinstance(prompt, list):
+                text_prompt_kwargs["prompt_token_ids"] = prompt
+            else:
+                text_prompt_kwargs["prompt_embeds"] = prompt
+
             inputs.append(TextPrompt(**text_prompt_kwargs))
 
         return inputs
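
For context (not part of the diff): the two hunks above extend VllmRunner.get_inputs so that, besides text strings and embedding tensors, it also accepts pre-tokenized prompts (list[int]) and forwards them as prompt_token_ids. Below is a minimal, hedged sketch of just that dispatch rule; the _dispatch_prompt helper is hypothetical and only mirrors the isinstance branches in the hunk, it is not part of vLLM.

from typing import Any, Union

import torch


def _dispatch_prompt(
        prompt: Union[str, list[int], torch.Tensor]) -> dict[str, Any]:
    # Hypothetical helper mirroring the branches added to get_inputs() above.
    if isinstance(prompt, str):
        return {"prompt": prompt}  # plain text prompt
    if isinstance(prompt, list):
        return {"prompt_token_ids": prompt}  # pre-tokenized prompt
    return {"prompt_embeds": prompt}  # otherwise assume an embeddings tensor


assert _dispatch_prompt("Hello") == {"prompt": "Hello"}
assert _dispatch_prompt([1, 2, 3]) == {"prompt_token_ids": [1, 2, 3]}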

@@ -47,9 +47,6 @@ async def test_basic_audio(mary_had_lamb, model_name):
     if model_name.startswith("mistralai"):
         server_args += MISTRAL_FORMAT_ARGS
 
-    # TODO(PATRICK) - REMOVE AFTER RELEASE
-    return  # skip for now
-
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
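
Aside, not part of the diff: the MISTRAL_FORMAT_ARGS used here is the same flag set defined in the new test_voxtral.py below ("--tokenizer_mode mistral --config_format mistral --load_format mistral"). A hedged sketch of launching the model manually with those flags through the RemoteOpenAIServer helper these tests use; the tests.utils import path is an assumption based on the relative imports in the test files, and the models.list() call is only a smoke check.

import asyncio

from tests.utils import RemoteOpenAIServer  # import path assumed

MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral", "--config_format", "mistral",
    "--load_format", "mistral"
]


async def main() -> None:
    # Start an OpenAI-compatible server with the mistral-format flags.
    with RemoteOpenAIServer("mistralai/Voxtral-Mini-3B-2507",
                            ["--enforce-eager"] + MISTRAL_FORMAT_ARGS) as server:
        async with server.get_async_client() as client:  # same helper as above
            print(await client.models.list())  # smoke check


if __name__ == "__main__":
    asyncio.run(main())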

tests/models/multimodal/generation/test_voxtral.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
                                                        TextChunk, UserMessage)

from vllm.transformers_utils.tokenizer import MistralTokenizer

from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test

MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral", "--config_format", "mistral",
    "--load_format", "mistral"
]


@pytest.fixture()
def server(request, audio_assets: AudioTestAssets):
    args = [
        "--enforce-eager",
        "--limit-mm-per-prompt",
        json.dumps({"audio": len(audio_assets)}),
    ] + MISTRAL_FORMAT_ARGS

    with RemoteOpenAIServer(MODEL_NAME,
                            args,
                            env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
                                      "30"}) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


def _get_prompt(audio_assets, question):
    tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)

    audios = [
        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
        for i in range(len(audio_assets))
    ]
    audio_chunks = [
        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
    ]

    text_chunk = TextChunk(text=question)
    messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]

    return tokenizer.apply_chat_template(messages=messages)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(vllm_runner,
                                     audio_assets: AudioTestAssets, dtype: str,
                                     max_tokens: int,
                                     num_logprobs: int) -> None:
    vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
    run_multi_audio_test(
        vllm_runner,
        [(vllm_prompt, [audio.audio_and_sample_rate
                        for audio in audio_assets])],
        MODEL_NAME,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tokenizer_mode="mistral",
    )


@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: AudioTestAssets):
    """Exercises online serving with/without chunked prefill enabled."""

    def asset_to_chunk(asset):
        audio = Audio.from_file(str(asset.get_local_path()), strict=False)
        audio.format = "wav"
        audio_dict = AudioChunk.from_audio(audio).to_openai()
        return audio_dict

    audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
    messages = [{
        "role":
        "user",
        "content": [
            *audio_chunks,
            {
                "type":
                "text",
                "text":
                f"What's happening in these {len(audio_assets)} audio clips?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10)

    assert len(chat_completion.choices) == 1
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
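
Side note, not part of the commit: _get_prompt returns the chat template already tokenized by the Mistral tokenizer, which is why the get_inputs change at the top of this diff adds a list[int] branch that forwards such prompts as prompt_token_ids. A minimal, hedged sketch of that round trip for a text-only message, assuming apply_chat_template returns a flat list of token IDs here (which is what the new branch implies).

from mistral_common.protocol.instruct.messages import TextChunk, UserMessage

from vllm.transformers_utils.tokenizer import MistralTokenizer

# Tokenize a minimal text-only chat the same way _get_prompt() does.
tokenizer = MistralTokenizer.from_pretrained("mistralai/Voxtral-Mini-3B-2507")
messages = [UserMessage(content=[TextChunk(text="Hello!")]).to_openai()]
prompt_token_ids = tokenizer.apply_chat_template(messages=messages)

# Assumption: a flat list of token IDs, ready to be fed to
# VllmRunner.get_inputs() via the new list[int] branch.
assert isinstance(prompt_token_ids, list)
assert all(isinstance(t, int) for t in prompt_token_ids)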

@@ -440,7 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                                          tokenizer="Isotr0py/Florence-2-tokenizer",  # noqa: E501
                                                          trust_remote_code=True),  # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
-    "VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"),  # noqa: E501
+    "VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", tokenizer_mode="mistral"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501
 
     # [Cross-encoder]