diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index f33efbab955e..9d7585914f5e 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -199,13 +199,6 @@ def main(args):
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)

-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
@@ -226,8 +219,15 @@ def main(args):
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)

-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    outputs = llm.generate(
+        inputs,
+        sampling_params=sampling_params,
+        lora_request=lora_request,
+    )

     for o in outputs:
         generated_text = o.outputs[0].text
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 7eb2e852f266..7b587f29b5a7 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -8,6 +8,7 @@ on HuggingFace model repository.
 """
 import os
 import random
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import NamedTuple, Optional

@@ -1055,6 +1056,20 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
     return inputs


+@contextmanager
+def time_counter(enable: bool):
+    if enable:
+        import time
+        start_time = time.time()
+        yield
+        elapsed_time = time.time() - start_time
+        print("-" * 50)
+        print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
+    else:
+        yield
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -1113,17 +1128,16 @@ def main(args):
             },
         } for i in range(args.num_prompts)]

-    if args.time_generate:
-        import time
-        start_time = time.time()
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
-        elapsed_time = time.time() - start_time
-        print("-" * 50)
-        print("-- generate time = {}".format(elapsed_time))
-        print("-" * 50)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)

-    else:
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
+    with time_counter(args.time_generate):
+        outputs = llm.generate(
+            inputs,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )

     print("-" * 50)
     for o in outputs:
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index d9f84d2feae8..1ac141d8a583 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -661,13 +661,6 @@ def run_generate(model, question: str, image_urls: list[str],
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)

-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=256,
                                      stop_token_ids=req_data.stop_token_ids)
@@ -679,7 +672,9 @@ def run_generate(model, question: str, image_urls: list[str],
                 "image": req_data.image_data
             },
         },
-        sampling_params=sampling_params)
+        sampling_params=sampling_params,
+        lora_request=req_data.lora_requests,
+    )

     print("-" * 50)
     for o in outputs:
@@ -724,6 +719,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
         }],
         sampling_params=sampling_params,
         chat_template=req_data.chat_template,
+        lora_request=req_data.lora_requests,
     )

     print("-" * 50)
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index e19cc241054b..5bd10544d81b 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -433,8 +433,8 @@ VLM_TEST_SETTINGS = {
         max_model_len=4096,
         max_num_seqs=2,
         task="generate",
-        # use eager mode for hf runner since phi3v didn't work with flash_attn
-        hf_model_kwargs={"_attn_implementation": "eager"},
+        # use sdpa mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "sdpa"},
         use_tokenizer_eos=True,
         vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         num_logprobs=10,
diff --git a/tests/models/decoder_only/vision_language/test_phi4mm.py b/tests/models/decoder_only/vision_language/test_phi4mm.py
index c3e88b60978a..3cd830015076 100644
--- a/tests/models/decoder_only/vision_language/test_phi4mm.py
+++ b/tests/models/decoder_only/vision_language/test_phi4mm.py
@@ -2,18 +2,22 @@

 import os
 import re
+from collections.abc import Sequence
 from typing import Optional

+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

+from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs

-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close

@@ -29,6 +33,8 @@ model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 models = [model_path]


@@ -64,7 +70,8 @@ if current_platform.is_rocm():
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
         enforce_eager=True,
     ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]

-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
@@ -138,8 +166,6 @@ def run_test(
         )


-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", @@ -151,7 +177,7 @@ def run_test( # Single-scale, batched [1.0, 1.0, 1.0], # Multi-scale - [0.7, 0.75, 1.0], + [0.25, 0.5, 1.0], ], ) @pytest.mark.parametrize("dtype", [target_dtype]) @@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, inputs_per_image = [( [prompt for _ in size_factors], [rescale_image_size(image, factor) for factor in size_factors], + None, ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] run_test( @@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [10000]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -@pytest.mark.xfail( - reason="Phi-4-MM multi-image inference is divergent with hf model.") def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_model_len: int, max_tokens: int, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) + ( + [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors], + None, + ), ] run_test( @@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, + max_model_len: int, max_tokens: int, + num_logprobs: int) -> None: + + # use the example speech question so that the model outputs are reasonable + audio = librosa.load(speech_question, sr=None) + image = ImageAsset("cherry_blossom").pil_image.convert("RGB") + + inputs_vision_speech = [ + ( + ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"], + [image], + [audio], + ), + ] + + run_test( + hf_runner, + vllm_runner, + inputs_vision_speech, + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + )