[VLM] Update compatibility with transformers 4.49

Authored by Cyrus Leung on 2025-02-06 11:09:45 +08:00, committed by GitHub
parent bf3b79efb8
commit 75404d041b
9 changed files with 48 additions and 57 deletions

View File

@@ -883,8 +883,7 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 :::
 :::{note}
-The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingface.co/mistral-community/pixtral-12b/discussions/22)).
-A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
+`mistral-community/pixtral-12b` does not support V1 yet.
 :::
 :::{note}

View File

@@ -1,38 +0,0 @@
-{%- if messages[0]["role"] == "system" %}
-    {%- set system_message = messages[0]["content"] %}
-    {%- set loop_messages = messages[1:] %}
-{%- else %}
-    {%- set loop_messages = messages %}
-{%- endif %}
-
-{{- bos_token }}
-{%- for message in loop_messages %}
-    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
-        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
-    {%- endif %}
-    {%- if message["role"] == "user" %}
-        {%- if loop.last and system_message is defined %}
-            {{- "[INST]" + system_message + "\n" }}
-        {%- else %}
-            {{- "[INST]" }}
-        {%- endif %}
-        {%- if message["content"] is not string %}
-            {%- for chunk in message["content"] %}
-                {%- if chunk["type"] == "text" %}
-                    {{- chunk["text"] }}
-                {%- elif chunk["type"] == "image" %}
-                    {{- "[IMG]" }}
-                {%- else %}
-                    {{- raise_exception("Unrecognized content type!") }}
-                {%- endif %}
-            {%- endfor %}
-        {%- else %}
-            {{- message["content"] }}
-        {%- endif %}
-        {{- "[/INST]" }}
-    {%- elif message["role"] == "assistant" %}
-        {{- message["content"] + eos_token}}
-    {%- else %}
-        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
-    {%- endif %}
-{%- endfor %}
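This override can be deleted because the upstream chat template for `mistral-community/pixtral-12b` was fixed (see the discussion linked in the doc change above), so the stock processor now renders the same `[INST]...[IMG]...[/INST]` layout on its own. A minimal sketch of the default path, not part of this diff (message shape follows the HF multimodal chat convention):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image"},
    ],
}]

# With the fixed upstream template this renders roughly
# "<s>[INST]What is in this image?[IMG][/INST]" with no override file.
prompt = processor.apply_chat_template(messages)
print(prompt)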

View File

@@ -761,7 +761,6 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     ("template_falcon.jinja", "string"),
     ("template_inkbot.jinja", "string"),
     ("template_llava.jinja", "string"),
-    ("template_pixtral_hf.jinja", "openai"),
     ("template_vlm2vec.jinja", "openai"),
     ("tool_chat_template_granite_20b_fc.jinja", "string"),
     ("tool_chat_template_hermes.jinja", "string"),

View File

@@ -224,7 +224,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
@@ -359,7 +359,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
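The reason strings are tightened to match the actual guard: under PEP 440 semantics (as implemented by `packaging`), `4.48` and `4.48.0` are the same version, so `transformers>=4.48` describes the skip condition exactly. A quick standalone check:

from packaging.version import Version

assert Version("4.48") == Version("4.48.0")
assert Version("4.48.1") >= Version("4.48")
assert not Version("4.47.1") >= Version("4.48")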

View File

@@ -4,7 +4,6 @@ from typing import List, Type

 import pytest
 import torch.nn.functional as F
-import transformers
 from transformers import AutoModelForVision2Seq

 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -57,6 +56,10 @@ def _run_test(
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
+        # Patch the issue where generation_config.json is missing
+        hf_model.processor.patch_size = \
+            hf_model.model.config.vision_config.patch_size
+
         # Patch the issue where image_token_id
         # exceeds the maximum allowed vocab size
         hf_model.model.resize_token_embeddings(
@@ -88,8 +91,6 @@ def _run_test(
     )

-@pytest.mark.skipif(transformers.__version__ >= "4.46",
-                    reason="Model broken with changes in transformers 4.46")
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])

View File

@@ -293,16 +293,29 @@ class PixtralHFMultiModalProcessor(
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, C, H, W)
-            # New output: (num_images, C, H, W)
-            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
-            assert (isinstance(pixel_values[0], list)
-                    and len(pixel_values[0]) == len(images))
-
-            processed_outputs["pixel_values"] = pixel_values[0]
+            # Before/after https://github.com/huggingface/transformers/pull/35122
+            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
+                images = mm_data["images"]
+                assert isinstance(images, list)
+
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1)
+                assert (isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
+
+                processed_outputs["pixel_values"] = pixel_values[0]
+            else:
+                # Avoid padding since we need the output for each image to be
+                # independent of other images for the cache to work correctly
+                image_sizes = processed_outputs["image_sizes"]
+                assert len(pixel_values) == len(image_sizes)
+
+                processed_outputs["pixel_values"] = [
+                    p[:, :h, :w]
+                    for p, (h, w) in zip(pixel_values, image_sizes)
+                ]

         return processed_outputs
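On transformers newer than 4.48.2, the processor returns one tensor per image, padded to a common height and width; the new branch crops each tensor back to its true size so an image's encoding never depends on what else is in the batch, which the cache relies on. A standalone illustration with made-up shapes:

import torch

# Two images padded by the processor to a common 64x64 canvas.
pixel_values = [torch.randn(3, 64, 64), torch.randn(3, 64, 64)]
image_sizes = [(48, 64), (64, 32)]  # true (height, width) of each image

cropped = [p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)]
assert [tuple(c.shape) for c in cropped] == [(3, 48, 64), (3, 64, 32)]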

View File

@@ -73,7 +73,15 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
         return self.ctx.get_hf_config(LlavaNextConfig)

     def get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaNextProcessor)
+        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor)
+
+        # In case patch_size is omitted from `processor_config.json`
+        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
+        if hf_processor.patch_size is None:
+            patch_size = self.get_vision_encoder_info().get_patch_size()
+            hf_processor.patch_size = patch_size
+
+        return hf_processor

     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
     def get_num_image_tokens(
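The same fallback can be reproduced outside vLLM: when a checkpoint's processor_config.json omits patch_size (as with E5-V), it can be recovered from the vision tower's config. A hedged sketch using plain transformers APIs, not code from this diff:

from transformers import AutoConfig, AutoProcessor

processor = AutoProcessor.from_pretrained("royokong/e5-v")
config = AutoConfig.from_pretrained("royokong/e5-v")

if getattr(processor, "patch_size", None) is None:
    processor.patch_size = config.vision_config.patch_size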

View File

@@ -342,6 +342,15 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         **kwargs: object,
     ):
         hf_processor = self.ctx.get_hf_processor()
+
+        # NumPy arrays are considered as Iterable but not Sequence in
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
+        image_processor = hf_processor.image_processor  # type: ignore
+        for attr in ("mean", "std"):
+            val = getattr(image_processor, attr)
+            if isinstance(val, np.ndarray):
+                setattr(image_processor, attr, val.tolist())
+
         return hf_processor

     def get_image_processor(self):
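The patch is needed because transformers' type dispatch accepts a Sequence for the normalization constants, and a NumPy array is iterable but not a Sequence. The mismatch is easy to verify standalone:

from collections.abc import Iterable, Sequence

import numpy as np

mean = np.array([0.5, 0.5, 0.5])
assert isinstance(mean, Iterable)
assert not isinstance(mean, Sequence)
assert isinstance(mean.tolist(), Sequence)  # .tolist() satisfies the check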

View File

@@ -141,9 +141,9 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     """Equality check between :data:`NestedTensors` objects."""
     if isinstance(a, torch.Tensor):
-        return isinstance(b, torch.Tensor) and bool((a == b).all().item())
+        return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
-        return isinstance(a, torch.Tensor) and bool((b == a).all().item())
+        return isinstance(a, torch.Tensor) and torch.equal(b, a)

     if isinstance(a, list):
         return (isinstance(b, list)
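Switching to torch.equal matters because elementwise == broadcasts: two tensors with different shapes can still compare "equal" under (a == b).all() (or raise outright when the shapes are not broadcastable), whereas torch.equal also requires matching shape. For example:

import torch

a = torch.ones(1, 3)
b = torch.ones(3)

assert bool((a == b).all().item())  # broadcasting hides the shape mismatch
assert not torch.equal(a, b)        # torch.equal rejects it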