diff --git a/requirements/test.in b/requirements/test.in index 9b574a09fcce..bbbd41e168a6 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test -transformers==4.51.3 +transformers==4.52.4 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. diff --git a/requirements/test.txt b/requirements/test.txt index 03aec80ac128..fb0eede080ff 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -794,7 +794,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.51.3 +transformers==4.52.4 # via # -r requirements/test.in # genai-perf diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index a5bbcfc22e9c..496850b19af4 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -226,6 +226,8 @@ VLM_TEST_SETTINGS = { img_idx_to_prompt=lambda idx: "", auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.blip2_vllm_to_hf_output, + # FIXME: https://github.com/huggingface/transformers/pull/38510 + marks=[pytest.mark.skip("Model is broken")], ), "chameleon": VLMTestInfo( models=["facebook/chameleon-7b"], @@ -281,10 +283,10 @@ VLM_TEST_SETTINGS = { multi_image_prompt="Describe the two images in detail.", # noqa: E501 max_model_len=4096, max_num_seqs=2, - dtype="bfloat16", auto_cls=AutoModelForImageTextToText, vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, patch_hf_runner=model_utils.gemma3_patch_hf_runner, + num_logprobs=10, ), "glm4v": VLMTestInfo( models=["THUDM/glm-4v-9b"], @@ -337,7 +339,8 @@ VLM_TEST_SETTINGS = { models=[ "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-2B", - "OpenGVLab/Mono-InternVL-2B", + # FIXME: Config cannot be loaded in transformers 4.52 + # "OpenGVLab/Mono-InternVL-2B", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 @@ -568,6 +571,8 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output, prompt_path_encoder=model_utils.qwen_prompt_path_encoder, + # FIXME: https://github.com/huggingface/transformers/issues/38358 + marks=[pytest.mark.skip("Model initialization fails")], ), "qwen2_vl": VLMTestInfo( models=["Qwen/Qwen2-VL-2B-Instruct"], diff --git a/tests/models/multimodal/generation/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py index b048cec5e5e0..a622957f96f6 100644 --- a/tests/models/multimodal/generation/test_florence2.py +++ b/tests/models/multimodal/generation/test_florence2.py @@ -100,6 +100,8 @@ def run_test( ) +# FIXME: https://github.com/huggingface/transformers/issues/38358 +@pytest.mark.skip("Model initialization fails") @pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize( diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index 14552010d376..c5ffa5f3a70a 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -29,7 +29,7 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs -MODEL_NAME = "ibm-granite/granite-speech-3.3-8b" +MODEL_NAME = "ibm-granite/granite-speech-3.3-2b" # Audio lora co-exists directly in the model directory, but # currently still needs to be passed directly to vLLM. audio_lora_path = MODEL_NAME diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index e4cd476a96b1..4e8465778e25 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -122,6 +122,10 @@ def run_test( for prompts, images, audios in inputs ] + # This error occurs inside `get_peft_model` + # FIXME: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/75 + pytest.skip("HF impl is not compatible with current transformers") + hf_model_kwargs = {"_attn_implementation": "sdpa"} with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 1b087191f636..af4c72f44b67 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -10,11 +10,12 @@ from typing import Optional, Union import numpy as np import numpy.typing as npt +import pytest import regex as re import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, - GenerationConfig) + GenerationConfig, GenerationMixin) from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side @@ -324,6 +325,16 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: hf_model.processor = processor + orig_generate = hf_model.model.generate + + def _generate(self, *args, **kwargs): + # FIXME: https://github.com/huggingface/transformers/issues/38333 + kwargs["disable_compile"] = True + + return orig_generate(*args, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + return hf_model @@ -610,6 +621,11 @@ def _internvl_generate( if getattr(self, "use_visual_token_mask", False): visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype) forward_kwargs["visual_token_mask"] = visual_token_mask + + # e.g. InternVL2-2B + if not isinstance(self.language_model, GenerationMixin): + pytest.skip("HF impl is not compatible with current transformers") + outputs = self.language_model.generate( **forward_kwargs, **generate_kwargs, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index be574435e099..1e6608955b31 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -245,7 +245,7 @@ def _test_processing_correctness_one( "adept/fuyu-8b", "google/gemma-3-4b-it", "THUDM/glm-4v-9b", - "ibm-granite/granite-speech-3.3-8b", + "ibm-granite/granite-speech-3.3-2b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL3-1B", diff --git a/tests/models/registry.py b/tests/models/registry.py index ed49676a9f5d..3e07dc0f322e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -160,17 +160,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct", - is_available_online=False, - min_transformers_version="4.52.2"), + min_transformers_version="4.53"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), - "Glm4ForCausalLM": _HfExamplesInfo( - "THUDM/GLM-4-32B-0414", - is_available_online=False, - min_transformers_version="4.52.dev0" - ), + "Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"), "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", @@ -181,8 +176,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501 - min_transformers_version="4.52.0"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), @@ -203,8 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), - "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1", - is_available_online=False), + "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True), @@ -243,10 +236,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b", - is_available_online=False), + "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 - is_available_online=False), + v0_only=True), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t", v0_only=True), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), @@ -256,7 +248,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True), "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", - is_available_online=False, + tokenizer="meta-llama/Llama-2-7b", trust_remote_code=True), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", @@ -275,8 +267,7 @@ _EMBEDDING_EXAMPLE_MODELS = { trust_remote_code=True), "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True, - hf_overrides={"architectures": - ["GteNewModel"]}), + hf_overrides={"architectures": ["GteNewModel"]}), # noqa: E501 "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", trust_remote_code=True), "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 @@ -298,10 +289,8 @@ _EMBEDDING_EXAMPLE_MODELS = { "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 - # The model on Huggingface is currently being updated, - # hence I temporarily mark it as not available online - "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - is_available_online=False), + "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + is_available_online=False), # noqa: E501 } _CROSS_ENCODER_EXAMPLE_MODELS = { @@ -327,8 +316,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), - "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501 - min_transformers_version="4.52.0"), # noqa: E501 + "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 @@ -347,7 +335,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True, v0_only=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - min_transformers_version="4.51", max_model_len=10240), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 @@ -360,8 +347,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", - max_transformers_version="4.48", - transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 @@ -399,10 +384,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 - "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B", - min_transformers_version="4.52"), - "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ", # noqa: E501 - min_transformers_version="4.52"), + "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), + "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 @@ -413,8 +396,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="Isotr0py/Florence-2-tokenizer", - trust_remote_code=True,), # noqa: E501 + tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501 + trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index af023d903438..98a58d01e2a1 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -21,6 +21,10 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") + # FIXME: Possible memory leak in the previous tests? + if model_arch == "GraniteSpeechForConditionalGeneration": + pytest.skip("Avoid OOM") + # Avoid OOM and reduce initialization time by only using 1 layer def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) @@ -41,6 +45,13 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): "num_hidden_layers": 1, }) + # e.g.: ibm-granite/granite-speech-3.3-2b + if hasattr(hf_config, "encoder_config"): + hf_config.encoder_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + return hf_config # Avoid calling model.forward() diff --git a/vllm/config.py b/vllm/config.py index f6ca9328b8a1..a07c41ddab19 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3139,6 +3139,8 @@ def _find_dtype( config_dtype = getattr(config.get_text_config(), "torch_dtype", None) if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "encoder_config"): + config_dtype = getattr(config.encoder_config, "torch_dtype", None) # Try to read the dtype of the weights if they are in safetensors format if config_dtype is None: diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 22efb707af73..7e15e57a4d03 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -111,7 +111,13 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): return self.ctx.get_hf_config(AyaVisionConfig) def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor: - return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs) + processor = self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs) + + # Temporary workaround since this processor has multiple image tokens + # See https://github.com/huggingface/transformers/issues/38350 + processor._check_special_mm_tokens = lambda *args, **kwargs: None + + return processor def get_image_processor(self) -> GotOcr2ImageProcessor: return self.get_hf_processor().image_processor @@ -188,9 +194,7 @@ class AyaVisionMultiModalProcessor( image_processor = hf_processor.image_processor # HF processor pops the `num_patches` kwarg, which is needed by vLLM - if (images := - mm_data.get("images")) is not None and '' in prompt: - assert isinstance(images, list) + if (images := mm_data.get("images")) is not None: parsed_images = (self._get_data_parser().parse_mm_data({ "image": images diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 4bc5e2a0cfae..de8596282ca9 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -22,8 +22,8 @@ from typing import Literal, Optional, TypedDict, Union import torch from torch import nn -from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, - Idefics3Processor) +from transformers import (AddedToken, BatchFeature, Idefics3Config, + Idefics3ImageProcessor, Idefics3Processor) from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear @@ -199,13 +199,21 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): return grid_w * grid_h + 1 + # TODO: Remove after requiring transformers>=4.52 + def _get_content(self, token: Union[AddedToken, str]) -> str: + if isinstance(token, str): + return token + + return token.content + def _get_image_token( self, processor: Optional[Idefics3Processor]) -> tuple[str, str, str]: if processor is None: processor = self.get_hf_processor() - image_token = processor.image_token.content - fake_image_token = processor.fake_image_token.content + + image_token = self._get_content(processor.image_token) + fake_image_token = self._get_content(processor.fake_image_token) global_image_token = processor.global_image_tag return image_token, fake_image_token, global_image_token