[Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (#16424)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent ed37599544
commit 93195146ea
@@ -199,13 +199,6 @@ def main(args):
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)

-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
@@ -226,8 +219,15 @@ def main(args):
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)

-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    outputs = llm.generate(
+        inputs,
+        sampling_params=sampling_params,
+        lora_request=lora_request,
+    )

     for o in outputs:
         generated_text = o.outputs[0].text
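The hunks above (and the matching ones in the other example scripts further down) change how the examples apply LoRA adapters: instead of registering them eagerly through llm.llm_engine.add_lora(), the adapter is now passed per call via the lora_request argument of llm.generate(). A minimal sketch of that pattern follows; the model name and adapter path are placeholders, not taken from this commit.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Sketch only: model name and adapter path are placeholders.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
    lora = LoRARequest("my_adapter", 1, "/path/to/lora_adapter")

    outputs = llm.generate(
        ["Hello, my name is"],
        sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
        # A single LoRARequest, or a list with one entry per prompt, as the
        # script above builds with req_data.lora_requests * args.num_prompts.
        lora_request=lora,
    )
    print(outputs[0].outputs[0].text)

Passing the request per call keeps prompts that carry no lora_request on the base model weights.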
@@ -8,6 +8,7 @@ on HuggingFace model repository.
 """
 import os
 import random
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import NamedTuple, Optional

@@ -1055,6 +1056,20 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
     return inputs


+@contextmanager
+def time_counter(enable: bool):
+    if enable:
+        import time
+        start_time = time.time()
+        yield
+        elapsed_time = time.time() - start_time
+        print("-" * 50)
+        print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
+    else:
+        yield
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
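The new time_counter helper factors the optional timing logic out of main() so the generate call only has to be written once. A self-contained sketch of the same pattern with a stand-in workload (the real script wraps llm.generate, as the next hunk shows):

    import time
    from contextlib import contextmanager

    @contextmanager
    def time_counter(enable: bool):
        # Mirrors the helper above: time the wrapped block only when enabled.
        if enable:
            start_time = time.time()
            yield
            print("-- generate time = {}".format(time.time() - start_time))
        else:
            yield

    with time_counter(True):
        sum(i * i for i in range(1_000_000))  # stand-in for llm.generate(...)

One caveat of the helper as written: if the wrapped block raises, the timing report is skipped; wrapping the yield in try/finally would make it robust to exceptions.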
@@ -1113,17 +1128,16 @@ def main(args):
         },
     } for i in range(args.num_prompts)]

-    if args.time_generate:
-        import time
-        start_time = time.time()
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
-        elapsed_time = time.time() - start_time
-        print("-" * 50)
-        print("-- generate time = {}".format(elapsed_time))
-        print("-" * 50)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)

-    else:
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
+    with time_counter(args.time_generate):
+        outputs = llm.generate(
+            inputs,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )

     print("-" * 50)
     for o in outputs:
@@ -661,13 +661,6 @@ def run_generate(model, question: str, image_urls: list[str],
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)

-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=256,
                                      stop_token_ids=req_data.stop_token_ids)
@@ -679,7 +672,9 @@ def run_generate(model, question: str, image_urls: list[str],
                 "image": req_data.image_data
             },
         },
-        sampling_params=sampling_params)
+        sampling_params=sampling_params,
+        lora_request=req_data.lora_requests,
+    )

     print("-" * 50)
     for o in outputs:
@@ -724,6 +719,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
         }],
         sampling_params=sampling_params,
         chat_template=req_data.chat_template,
+        lora_request=req_data.lora_requests,
     )

     print("-" * 50)
@@ -433,8 +433,8 @@ VLM_TEST_SETTINGS = {
         max_model_len=4096,
         max_num_seqs=2,
         task="generate",
-        # use eager mode for hf runner since phi3v didn't work with flash_attn
-        hf_model_kwargs={"_attn_implementation": "eager"},
+        # use sdpa mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "sdpa"},
         use_tokenizer_eos=True,
         vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         num_logprobs=10,
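This switches the HuggingFace reference model for the phi3v test entry from eager attention to PyTorch SDPA; the hf_model_kwargs dict is forwarded to the Transformers from_pretrained call by the test runner. A hedged sketch of what that selection looks like with the public Transformers keyword (the test passes the private _attn_implementation key through the runner; the model id below is only illustrative):

    from transformers import AutoModelForCausalLM

    # Illustration only: pick the SDPA attention backend for the reference
    # model; "eager" and "flash_attention_2" are the other common options.
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )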
@@ -2,18 +2,22 @@

 import os
 import re
+from collections.abc import Sequence
 from typing import Optional

+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

+from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs

-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close

@@ -29,6 +33,8 @@ model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 models = [model_path]


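The new speech_question constant points at a sample audio prompt that ships inside the downloaded model repository. librosa.load(path, sr=None) decodes it at its native sampling rate and returns a (waveform, sampling_rate) pair, which is the shape the runners below expect for audio inputs. A small standalone sketch:

    import os

    import librosa
    from huggingface_hub import snapshot_download

    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    speech_question = os.path.join(model_path, "examples",
                                   "what_is_shown_in_this_image.wav")

    # sr=None keeps the file's native sampling rate instead of resampling.
    waveform, sampling_rate = librosa.load(speech_question, sr=None)
    print(waveform.shape, sampling_rate)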
@@ -64,7 +70,8 @@ if current_platform.is_rocm():
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
                      enforce_eager=True,
                      ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]

-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
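The reason for the processor patch above: the HF runner evidently forwards audio inputs to the processor as separate audio= and sampling_rate= keyword arguments, while the Phi-4-MM processor expects a single audios=[(waveform, sampling_rate)] argument. The wrapper bridges the two call shapes; the name adapt_audio_kwargs in this standalone restatement is hypothetical.

    def adapt_audio_kwargs(processor, *args, audio=None, sampling_rate=None,
                           **kwargs):
        # Fold the runner-style audio=/sampling_rate= keywords into the
        # audios=[(waveform, rate)] argument the Phi-4-MM processor expects.
        audios = None
        if audio is not None and sampling_rate is not None:
            audios = [(audio, sampling_rate)]
        return processor(*args, audios=audios, **kwargs)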
@@ -138,8 +166,6 @@ def run_test(
         )


-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -151,7 +177,7 @@ def run_test(
         # Single-scale, batched
         [1.0, 1.0, 1.0],
         # Multi-scale
-        [0.7, 0.75, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     inputs_per_image = [(
         [prompt for _ in size_factors],
         [rescale_image_size(image, factor) for factor in size_factors],
+        None,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

     run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("max_model_len", [10000])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.xfail(
-    reason="Phi-4-MM multi-image inference is divergent with hf model.")
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                              size_factors, dtype: str, max_model_len: int,
                              max_tokens: int, num_logprobs: int) -> None:
     images = [asset.pil_image for asset in image_assets]

     inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
-         [[rescale_image_size(image, factor) for image in images]
-          for factor in size_factors])
+        (
+            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+            [[rescale_image_size(image, factor) for image in images]
+             for factor in size_factors],
+            None,
+        ),
     ]

     run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+                              max_model_len: int, max_tokens: int,
+                              num_logprobs: int) -> None:
+
+    # use the example speech question so that the model outputs are reasonable
+    audio = librosa.load(speech_question, sr=None)
+    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+    inputs_vision_speech = [
+        (
+            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
+            [image],
+            [audio],
+        ),
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_vision_speech,
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
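The new test_vision_speech_models case feeds a single prompt that references both an image placeholder and an audio placeholder. Outside the test harness, the same input can be expressed as one multi-modal request; the sketch below shows only that input shape (engine construction and the LoRA setup from the hunks above are deliberately omitted):

    import os

    import librosa
    from huggingface_hub import snapshot_download
    from vllm.assets.image import ImageAsset

    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    speech_question = os.path.join(model_path, "examples",
                                   "what_is_shown_in_this_image.wav")

    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
    audio = librosa.load(speech_question, sr=None)  # (waveform, sampling_rate)

    vision_speech_input = {
        "prompt": "<|user|><|image_1|><|audio_1|><|end|><|assistant|>",
        "multi_modal_data": {"image": image, "audio": audio},
    }
    # vision_speech_input can then be handed to LLM.generate(...) together
    # with SamplingParams and, as in the example hunks above, a lora_request.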