diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 3e8b2f89642c4..ef7e77fa3ec61 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -883,8 +883,7 @@ For more details, please see:
 :::
 
 :::{note}
-The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingface.co/mistral-community/pixtral-12b/discussions/22)).
-A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
+`mistral-community/pixtral-12b` does not support V1 yet.
 :::
 
 :::{note}
diff --git a/examples/template_pixtral_hf.jinja b/examples/template_pixtral_hf.jinja
deleted file mode 100644
index e94661cb39071..0000000000000
--- a/examples/template_pixtral_hf.jinja
+++ /dev/null
@@ -1,38 +0,0 @@
-{%- if messages[0]["role"] == "system" %}
-    {%- set system_message = messages[0]["content"] %}
-    {%- set loop_messages = messages[1:] %}
-{%- else %}
-    {%- set loop_messages = messages %}
-{%- endif %}
-
-{{- bos_token }}
-{%- for message in loop_messages %}
-    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
-        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
-    {%- endif %}
-    {%- if message["role"] == "user" %}
-        {%- if loop.last and system_message is defined %}
-            {{- "[INST]" + system_message + "\n" }}
-        {%- else %}
-            {{- "[INST]" }}
-        {%- endif %}
-        {%- if message["content"] is not string %}
-            {%- for chunk in message["content"] %}
-                {%- if chunk["type"] == "text" %}
-                    {{- chunk["text"] }}
-                {%- elif chunk["type"] == "image" %}
-                    {{- "[IMG]" }}
-                {%- else %}
-                    {{- raise_exception("Unrecognized content type!") }}
-                {%- endif %}
-            {%- endfor %}
-        {%- else %}
-            {{- message["content"] }}
-        {%- endif %}
-        {{- "[/INST]" }}
-    {%- elif message["role"] == "assistant" %}
-        {{- message["content"] + eos_token}}
-    {%- else %}
-        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
-    {%- endif %}
-{%- endfor %}
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 737f733092b6d..5c469007af23e 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -761,7 +761,6 @@ def test_resolve_content_format_hf_defined(model, expected_format):
      ("template_falcon.jinja", "string"),
      ("template_inkbot.jinja", "string"),
      ("template_llava.jinja", "string"),
-     ("template_pixtral_hf.jinja", "openai"),
      ("template_vlm2vec.jinja", "openai"),
      ("tool_chat_template_granite_20b_fc.jinja", "string"),
      ("tool_chat_template_hermes.jinja", "string"),
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 95505dcf5c29f..b00ec6fa69995 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -224,7 +224,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
@@ -359,7 +359,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py
index 6ba3c5403896c..990c6c150fcdc 100644
--- a/tests/models/embedding/vision_language/test_llava_next.py
+++ b/tests/models/embedding/vision_language/test_llava_next.py
@@ -4,7 +4,6 @@ from typing import List, Type
 
 import pytest
 import torch.nn.functional as F
-import transformers
 from transformers import AutoModelForVision2Seq
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -57,6 +56,10 @@ def _run_test(
 
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
+        # Patch the issue where the processor config omits patch_size
+        hf_model.processor.patch_size = \
+            hf_model.model.config.vision_config.patch_size
+
         # Patch the issue where image_token_id
         # exceeds the maximum allowed vocab size
         hf_model.model.resize_token_embeddings(
@@ -88,8 +91,6 @@ def _run_test(
     )
 
 
-@pytest.mark.skipif(transformers.__version__ >= "4.46",
-                    reason="Model broken with changes in transformers 4.46")
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 63d308ef6d191..b1fee3eeb542f 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -293,16 +293,29 @@ class PixtralHFMultiModalProcessor(
 
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
+            # Before/after https://github.com/huggingface/transformers/pull/35122
+            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
+                images = mm_data["images"]
+                assert isinstance(images, list)
 
-            # Original output: (1, num_images, C, H, W)
-            # New output: (num_images, C, H, W)
-            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
-            assert (isinstance(pixel_values[0], list)
-                    and len(pixel_values[0]) == len(images))
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1)
+                assert (isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
 
-            processed_outputs["pixel_values"] = pixel_values[0]
+                processed_outputs["pixel_values"] = pixel_values[0]
+            else:
+                # Avoid padding since we need the output for each image to be
+                # independent of other images for the cache to work correctly
+                image_sizes = processed_outputs["image_sizes"]
+                assert len(pixel_values) == len(image_sizes)
+
+                processed_outputs["pixel_values"] = [
+                    p[:, :h, :w]
+                    for p, (h, w) in zip(pixel_values, image_sizes)
+                ]
 
         return processed_outputs
 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index defdeb54afb6a..719916642f25c 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -73,7 +73,15 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
         return self.ctx.get_hf_config(LlavaNextConfig)
 
     def get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaNextProcessor)
+        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor)
+
+        # In case patch_size is omitted from `processor_config.json`
+        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
+        if hf_processor.patch_size is None:
+            patch_size = self.get_vision_encoder_info().get_patch_size()
+            hf_processor.patch_size = patch_size
+
+        return hf_processor
 
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
     def get_num_image_tokens(
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 20f3a3d1989be..58a4448d436aa 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -342,6 +342,15 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         **kwargs: object,
     ):
         hf_processor = self.ctx.get_hf_processor()
+
+        # NumPy arrays are considered as Iterable but not Sequence in
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
+        image_processor = hf_processor.image_processor  # type: ignore
+        for attr in ("mean", "std"):
+            val = getattr(image_processor, attr)
+            if isinstance(val, np.ndarray):
+                setattr(image_processor, attr, val.tolist())
+
         return hf_processor
 
     def get_image_processor(self):
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 2f2535f368cff..5f9593ee8b205 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -141,9 +141,9 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     """Equality check between :data:`NestedTensors` objects."""
     if isinstance(a, torch.Tensor):
-        return isinstance(b, torch.Tensor) and bool((a == b).all().item())
+        return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
-        return isinstance(a, torch.Tensor) and bool((b == a).all().item())
+        return isinstance(a, torch.Tensor) and torch.equal(b, a)
 
     if isinstance(a, list):
         return (isinstance(b, list)