diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index f091d191c6e92..cdc487c1e78d4 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -429,7 +429,7 @@ steps:
     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
     - pytest -v -s models/multimodal
     - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+    - pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model'
     - pytest -v -s models/embedding/vision_language -m core_model
     - pytest -v -s models/encoder_decoder/audio_language -m core_model
     - pytest -v -s models/encoder_decoder/language -m core_model
@@ -448,10 +448,7 @@ steps:
     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
     - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
     - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
-    # HACK - run phi3v tests separately to sidestep this transformers bug
-    # https://github.com/huggingface/transformers/issues/34307
-    - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
     - pytest -v -s models/embedding/vision_language -m 'not core_model'
     - pytest -v -s models/encoder_decoder/language -m 'not core_model'
     - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 64c841ced8f47..e19cc241054bc 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -425,23 +425,20 @@ VLM_TEST_SETTINGS = {
         max_num_seqs=2,
         patch_hf_runner=model_utils.molmo_patch_hf_runner,
     ),
-    # Tests for phi3v currently live in another file because of a bug in
-    # transformers. Once this issue is fixed, we can enable them here instead.
-    # https://github.com/huggingface/transformers/issues/34307
-    # "phi3v": VLMTestInfo(
-    #     models=["microsoft/Phi-3.5-vision-instruct"],
-    #     test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-    #     prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n",  # noqa: E501
-    #     img_idx_to_prompt=lambda idx: f"<|image_{idx}|>\n",
-    #     max_model_len=4096,
-    #     max_num_seqs=2,
-    #     task="generate",
-    #     # use eager mode for hf runner since phi3v didn't work with flash_attn
-    #     hf_model_kwargs={"_attn_implementation": "eager"},
-    #     use_tokenizer_eos=True,
-    #     vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
-    #     num_logprobs=10,
-    # ),
+    "phi3v": VLMTestInfo(
+        models=["microsoft/Phi-3.5-vision-instruct"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: f"<|image_{idx}|>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        task="generate",
+        # use eager mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "eager"},
+        use_tokenizer_eos=True,
+        vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
+        num_logprobs=10,
+    ),
     "pixtral_hf": VLMTestInfo(
         models=["nm-testing/pixtral-12b-FP8-dynamic"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py
deleted file mode 100644
index 237d499d8f6ad..0000000000000
--- a/tests/models/decoder_only/vision_language/test_phi3v.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-import re
-from typing import Optional
-
-import pytest
-from packaging.version import Version
-from transformers import AutoTokenizer
-from transformers import __version__ as TRANSFORMERS_VERSION
-
-from vllm.multimodal.image import rescale_image_size
-from vllm.platforms import current_platform
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
-from ...utils import check_logprobs_close
-
-HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
-    "stop_sign":
-    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
-    "cherry_blossom":
-    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
-})
-HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501
-
-models = ["microsoft/Phi-3.5-vision-instruct"]
-
-
-def vllm_to_hf_output(vllm_output: tuple[list[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    _, output_str, out_logprobs = vllm_output
-
-    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
-    assert output_str_without_image[0] == " "
-    output_str_without_image = output_str_without_image[1:]
-
-    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    hf_output_ids = tokenizer.encode(output_str_without_image)
-    assert hf_output_ids[0] == 1
-    hf_output_ids = hf_output_ids[1:]
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-target_dtype = "half"
-
-# ROCm Triton FA can run into shared memory issues with these models,
-# use other backends in the meantime
-# FIXME (mattwong, gshtrasb, hongxiayan)
-if current_platform.is_rocm():
-    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
-
-
-def run_test(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    mm_limit: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    """Inference result should be the same between hf and vllm.
-
-    All the image fixtures for the test are from IMAGE_ASSETS.
-    For huggingface runner, we provide the PIL images as input.
-    For vllm runner, we provide MultiModalDataDict objects
-    and corresponding MultiModalConfig as input.
-    Note, the text input is also adjusted to abide by vllm contract.
-    The text output is sanitized to be able to compare with hf.
-    """
-    # HACK - this is an attempted workaround for the following bug
-    # https://github.com/huggingface/transformers/issues/34307
-    from transformers import AutoImageProcessor  # noqa: F401
-    from transformers import AutoProcessor  # noqa: F401
-
-    # Once the model repo is updated to 4.49, we should be able to run the
-    # test in `test_models.py` without the above workaround
-    if Version(TRANSFORMERS_VERSION) >= Version("4.49"):
-        pytest.skip(f"`transformers=={TRANSFORMERS_VERSION}` installed, "
-                    "but `transformers<=4.49` is required to run this model. "
-                    "Reason: Cannot run HF implementation")
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-    # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     task="generate",
-                     max_model_len=4096,
-                     max_num_seqs=2,
-                     dtype=dtype,
-                     limit_mm_per_prompt={"image": mm_limit},
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
-        vllm_outputs_per_case = [
-            vllm_model.generate_greedy_logprobs(prompts,
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
-        ]
-
-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
-    with hf_runner(model, dtype=dtype,
-                   model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
-        hf_outputs_per_case = [
-            hf_model.generate_greedy_logprobs_limit(prompts,
-                                                    max_tokens,
-                                                    num_logprobs=num_logprobs,
-                                                    images=images,
-                                                    eos_token_id=eos_token_id)
-            for prompts, images in inputs
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
-                                        vllm_outputs_per_case):
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No image
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-        # Multi-scale
-        [0.25, 0.5, 1.0],
-    ],
-)
-@pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [10])
-def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
-                dtype: str, max_tokens: int, num_logprobs: int) -> None:
-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        inputs_per_image,
-        model,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        mm_limit=1,
-        tensor_parallel_size=1,
-    )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("dtype", [target_dtype])
-def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
-                         dtype) -> None:
-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_regresion_7840 = [
-        ([prompt], [image]) for image, prompt in zip(images, HF_IMAGE_PROMPTS)
-    ]
-
-    # Regression test for #7840.
-    run_test(
-        hf_runner,
-        vllm_runner,
-        inputs_regresion_7840,
-        model,
-        dtype=dtype,
-        max_tokens=128,
-        num_logprobs=10,
-        mm_limit=1,
-        tensor_parallel_size=1,
-    )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No image
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-        # Multi-scale
-        [0.25, 0.5, 1.0],
-    ],
-)
-@pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [10])
-def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
-                             size_factors, dtype: str, max_tokens: int,
-                             num_logprobs: int) -> None:
-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
-         [[rescale_image_size(image, factor) for image in images]
-          for factor in size_factors])
-    ]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        inputs_per_case,
-        model,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        mm_limit=2,
-        tensor_parallel_size=1,
-    )
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 40479fb8a5b07..b43bdb9c788da 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -326,7 +326,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                         extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
                                         max_transformers_version="4.48",
-                                        transformers_version_reason="Use of private method which no longer exists.",  # noqa: E501
+                                        transformers_version_reason="Incorrectly-detected `tensorflow` import.",  # noqa: E501
                                         extras={"olmo": "allenai/Molmo-7B-O-0924"},  # noqa: E501
                                         trust_remote_code=True),
     "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
@@ -335,6 +335,8 @@
                                                extras={"v2": "google/paligemma2-3b-ft-docci-448"}),  # noqa: E501
     "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
                                         trust_remote_code=True,
+                                        max_transformers_version="4.48",
+                                        transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                                         extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
@@ -351,8 +353,7 @@
     "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"),  # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
-                                     trust_remote_code=True,
-                                     max_transformers_version="4.50"),
+                                     trust_remote_code=True),
     # [Encoder-decoder]
     # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
     # Therefore, we borrow the BartTokenizer from the original Bart model