[Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (#16424)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-04-11 12:57:16 +08:00 committed by GitHub
parent ed37599544
commit 93195146ea
5 changed files with 118 additions and 45 deletions

View File

@@ -199,13 +199,6 @@ def main(args):
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
@@ -226,8 +219,15 @@ def main(args):
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
 
-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
+    outputs = llm.generate(
+        inputs,
+        sampling_params=sampling_params,
+        lora_request=lora_request,
+    )
 
     for o in outputs:
         generated_text = o.outputs[0].text
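
As an aside, here is a minimal sketch of the per-request LoRA pattern this example now uses (model name, adapter path, and engine flags below are placeholders, not taken from the diff): instead of registering adapters on the engine via add_lora(), a LoRARequest is passed directly to llm.generate(), one per prompt when batching.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Engine arguments here are illustrative assumptions only.
llm = LLM(model="microsoft/Phi-4-multimodal-instruct",
          trust_remote_code=True,
          enable_lora=True)
lora = LoRARequest("vision", 1, "/path/to/vision-lora")  # placeholder path

prompts = ["Describe the weather in one sentence."] * 4
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.2, max_tokens=64),
                       lora_request=[lora] * len(prompts))  # one per prompt
for o in outputs:
    print(o.outputs[0].text)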

View File

@@ -8,6 +8,7 @@ on HuggingFace model repository.
 """
 import os
 import random
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import NamedTuple, Optional
@@ -1055,6 +1056,20 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
     return inputs
 
 
+@contextmanager
+def time_counter(enable: bool):
+    if enable:
+        import time
+        start_time = time.time()
+        yield
+        elapsed_time = time.time() - start_time
+        print("-" * 50)
+        print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
+    else:
+        yield
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -1113,17 +1128,16 @@ def main(args):
         },
     } for i in range(args.num_prompts)]
 
-    if args.time_generate:
-        import time
-        start_time = time.time()
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
-        elapsed_time = time.time() - start_time
-        print("-" * 50)
-        print("-- generate time = {}".format(elapsed_time))
-        print("-" * 50)
-    else:
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
+
+    with time_counter(args.time_generate):
+        outputs = llm.generate(
+            inputs,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )
 
     print("-" * 50)
     for o in outputs:
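
The new time_counter helper introduced above is a plain generator-based context manager: when enabled it times whatever block it wraps, otherwise it simply yields. A tiny standalone sketch of the same pattern (the sleep stands in for the generate call):

import time
from contextlib import contextmanager

@contextmanager
def time_counter(enable: bool):
    if enable:
        start_time = time.time()
        yield
        print("-- generate time = {}".format(time.time() - start_time))
    else:
        yield

with time_counter(True):
    time.sleep(0.1)  # stands in for llm.generate(...)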

View File

@@ -661,13 +661,6 @@ def run_generate(model, question: str, image_urls: list[str],
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=256,
                                      stop_token_ids=req_data.stop_token_ids)
@@ -679,7 +672,9 @@ def run_generate(model, question: str, image_urls: list[str],
                 "image": req_data.image_data
             },
         },
-        sampling_params=sampling_params)
+        sampling_params=sampling_params,
+        lora_request=req_data.lora_requests,
+    )
 
     print("-" * 50)
     for o in outputs:
@@ -724,6 +719,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
         }],
         sampling_params=sampling_params,
         chat_template=req_data.chat_template,
+        lora_request=req_data.lora_requests,
     )
 
     print("-" * 50)

View File

@@ -433,8 +433,8 @@ VLM_TEST_SETTINGS = {
         max_model_len=4096,
         max_num_seqs=2,
         task="generate",
-        # use eager mode for hf runner since phi3v didn't work with flash_attn
-        hf_model_kwargs={"_attn_implementation": "eager"},
+        # use sdpa mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "sdpa"},
         use_tokenizer_eos=True,
         vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         num_logprobs=10,
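
This settings change only swaps the attention backend used by the HuggingFace reference model; flash_attn is still avoided for phi3v. In plain transformers terms it corresponds roughly to the following (the checkpoint name is an assumption; the test runner forwards the underscore-prefixed key through model_kwargs):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-vision-128k-instruct",  # phi3v checkpoint, assumed
    trust_remote_code=True,
    attn_implementation="sdpa",  # previously eager in this test config
)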

View File

@@ -2,18 +2,22 @@
 import os
 import re
+from collections.abc import Sequence
 from typing import Optional
 
+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
+from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
@@ -29,6 +33,8 @@ model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 
 models = [model_path]
@@ -64,7 +70,8 @@ if current_platform.is_rocm():
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
         enforce_eager=True,
     ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]
 
-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
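
The processor monkey-patch added above exists because the shared HF runner invokes the processor with audio= and sampling_rate= keyword arguments, while the Phi-4-MM processor expects audios=[(waveform, sampling_rate)]. A stripped-down sketch of that adapter, with an illustrative helper name:

def make_patched_processor(hf_processor):
    # Repack audio/sampling_rate kwargs into the audios=[(waveform, rate)]
    # form the Phi-4-MM processor expects, passing everything else through.
    def patched(*args, text="", images=None, audio=None, sampling_rate=None,
                **kwargs):
        audios = None
        if audio is not None and sampling_rate is not None:
            audios = [(audio, sampling_rate)]
        return hf_processor(*args, text=text, images=images, audios=audios,
                            **kwargs)
    return patched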
@@ -138,8 +166,6 @@ def run_test(
         )
 
 
-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -151,7 +177,7 @@ def run_test(
         # Single-scale, batched
         [1.0, 1.0, 1.0],
         # Multi-scale
-        [0.7, 0.75, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     inputs_per_image = [(
         [prompt for _ in size_factors],
         [rescale_image_size(image, factor) for factor in size_factors],
+        None,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("max_model_len", [10000])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.xfail(
-    reason="Phi-4-MM multi-image inference is divergent with hf model.")
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                              size_factors, dtype: str, max_model_len: int,
                              max_tokens: int, num_logprobs: int) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+        (
+            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
-         for factor in size_factors])
+         for factor in size_factors],
+            None,
+        ),
     ]
 
     run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+                              max_model_len: int, max_tokens: int,
+                              num_logprobs: int) -> None:
+
+    # use the example speech question so that the model outputs are reasonable
+    audio = librosa.load(speech_question, sr=None)
+    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+    inputs_vision_speech = [
+        (
+            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
+            [image],
+            [audio],
+        ),
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_vision_speech,
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
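
Outside the test harness, what the new vision-speech test exercises boils down to a single vLLM request carrying both an image and an audio clip, with the vision LoRA attached per request. A rough offline-inference sketch under assumed paths and engine arguments (not part of this commit):

import librosa
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

llm = LLM(model="microsoft/Phi-4-multimodal-instruct",  # assumed checkpoint
          trust_remote_code=True,
          enable_lora=True,
          max_model_len=10000)

# librosa.load returns a (waveform, sample_rate) tuple, which matches the
# shape vLLM expects for the "audio" modality.
audio = librosa.load("what_is_shown_in_this_image.wav", sr=None)  # local copy, assumed
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

outputs = llm.generate(
    {
        "prompt": "<|user|><|image_1|><|audio_1|><|end|><|assistant|>",
        "multi_modal_data": {"image": image, "audio": audio},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
    lora_request=LoRARequest("vision", 1, "/path/to/vision-lora"),  # placeholder path
)
print(outputs[0].outputs[0].text)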