[VLM] Update compatibility with transformers 4.49
Commit 75404d041b (parent bf3b79efb8)
@@ -883,8 +883,7 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 :::
 
 :::{note}
-The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingface.co/mistral-community/pixtral-12b/discussions/22)).
-A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
+`mistral-community/pixtral-12b` does not support V1 yet.
 :::
 
 :::{note}
@@ -1,38 +0,0 @@
-{%- if messages[0]["role"] == "system" %}
-    {%- set system_message = messages[0]["content"] %}
-    {%- set loop_messages = messages[1:] %}
-{%- else %}
-    {%- set loop_messages = messages %}
-{%- endif %}
-
-{{- bos_token }}
-{%- for message in loop_messages %}
-    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
-        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
-    {%- endif %}
-    {%- if message["role"] == "user" %}
-        {%- if loop.last and system_message is defined %}
-            {{- "[INST]" + system_message + "\n" }}
-        {%- else %}
-            {{- "[INST]" }}
-        {%- endif %}
-        {%- if message["content"] is not string %}
-            {%- for chunk in message["content"] %}
-                {%- if chunk["type"] == "text" %}
-                    {{- chunk["text"] }}
-                {%- elif chunk["type"] == "image" %}
-                    {{- "[IMG]" }}
-                {%- else %}
-                    {{- raise_exception("Unrecognized content type!") }}
-                {%- endif %}
-            {%- endfor %}
-        {%- else %}
-            {{- message["content"] }}
-        {%- endif %}
-        {{- "[/INST]" }}
-    {%- elif message["role"] == "assistant" %}
-        {{- message["content"] + eos_token}}
-    {%- else %}
-        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
-    {%- endif %}
-{%- endfor %}
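For reference, a minimal sketch of how a standalone Jinja template like the one removed above can be rendered with transformers (the model id comes from the surrounding docs; the local template path is hypothetical, and only the tokenizer is loaded so no weights are downloaded):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image"},
    ],
}]

# chat_template= overrides the template bundled with the tokenizer
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=open("template_pixtral_hf.jinja").read(),  # hypothetical local copy
    tokenize=False,
)
# Expected to render roughly: "<s>[INST]What is in this image?[IMG][/INST]"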
@@ -761,7 +761,6 @@ def test_resolve_content_format_hf_defined(model, expected_format):
         ("template_falcon.jinja", "string"),
         ("template_inkbot.jinja", "string"),
         ("template_llava.jinja", "string"),
-        ("template_pixtral_hf.jinja", "openai"),
         ("template_vlm2vec.jinja", "openai"),
         ("tool_chat_template_granite_20b_fc.jinja", "string"),
         ("tool_chat_template_hermes.jinja", "string"),
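The remaining entries map each example template to the content format it expects. As an illustration (payloads are hypothetical), the two formats differ in how message content is laid out:

# "string" format: content is a plain string
string_msg = {"role": "user", "content": "What is in this image?"}

# "openai" format: content is a list of typed parts
openai_msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
    ],
}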
@@ -224,7 +224,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
@@ -359,7 +359,7 @@ VLM_TEST_SETTINGS = {
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
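These markers compare parsed versions via packaging.version.Version rather than raw strings; a small standalone sketch of why that matters:

from packaging.version import Version

# Lexicographic string comparison mis-orders version numbers:
assert "4.9" > "4.48"                    # wrong: compares "9" > "4" character-wise
assert Version("4.9") < Version("4.48")  # correct semantic ordering

# Trailing zeros and pre-releases behave as expected:
assert Version("4.48") == Version("4.48.0")
assert Version("4.48.0.dev0") < Version("4.48")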
@@ -4,7 +4,6 @@ from typing import List, Type
 
 import pytest
 import torch.nn.functional as F
-import transformers
 from transformers import AutoModelForVision2Seq
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -57,6 +56,10 @@ def _run_test(
 
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
+        # Patch the issue where generation_config.json is missing
+        hf_model.processor.patch_size = \
+            hf_model.model.config.vision_config.patch_size
+
         # Patch the issue where image_token_id
         # exceeds the maximum allowed vocab size
         hf_model.model.resize_token_embeddings(
@@ -88,8 +91,6 @@ def _run_test(
     )
 
 
-@pytest.mark.skipif(transformers.__version__ >= "4.46",
-                    reason="Model broken with changes in transformers 4.46")
@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
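Outside of the test harness, the same two workarounds could be sketched against plain transformers objects (model id taken from the surrounding tests; the resize target is an assumption, since the diff truncates that call):

from transformers import AutoModelForVision2Seq, AutoProcessor

model = AutoModelForVision2Seq.from_pretrained("mistral-community/pixtral-12b")
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")

# Workaround 1: processor.patch_size may be unset, so copy it from the config
processor.patch_size = model.config.vision_config.patch_size

# Workaround 2: grow the embedding table so image_token_id stays in range
# (resizing to the tokenizer length is an assumed target, not from the diff)
model.resize_token_embeddings(len(processor.tokenizer))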
@@ -293,16 +293,29 @@ class PixtralHFMultiModalProcessor(
 
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
+            # Before/after https://github.com/huggingface/transformers/pull/35122
+            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
+                images = mm_data["images"]
+                assert isinstance(images, list)
 
-            # Original output: (1, num_images, C, H, W)
-            # New output: (num_images, C, H, W)
-            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
-            assert (isinstance(pixel_values[0], list)
-                    and len(pixel_values[0]) == len(images))
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1)
+                assert (isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
 
-            processed_outputs["pixel_values"] = pixel_values[0]
+                processed_outputs["pixel_values"] = pixel_values[0]
+            else:
+                # Avoid padding since we need the output for each image to be
+                # independent of other images for the cache to work correctly
+                image_sizes = processed_outputs["image_sizes"]
+                assert len(pixel_values) == len(image_sizes)
+
+                processed_outputs["pixel_values"] = [
+                    p[:, :h, :w]
+                    for p, (h, w) in zip(pixel_values, image_sizes)
+                ]
 
         return processed_outputs
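A self-contained sketch of the new un-padding branch: each padded tensor is cropped back to its true (height, width), keeping every image's output independent of its batch mates:

import torch

pixel_values = [torch.randn(3, 64, 64), torch.randn(3, 64, 64)]  # padded batch
image_sizes = [(48, 64), (64, 32)]                               # true sizes

cropped = [p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)]
assert cropped[0].shape == (3, 48, 64)
assert cropped[1].shape == (3, 64, 32)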
@@ -73,7 +73,15 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
         return self.ctx.get_hf_config(LlavaNextConfig)
 
     def get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaNextProcessor)
+        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor)
+
+        # In case patch_size is omitted from `processor_config.json`
+        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
+        if hf_processor.patch_size is None:
+            patch_size = self.get_vision_encoder_info().get_patch_size()
+            hf_processor.patch_size = patch_size
+
+        return hf_processor
 
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
     def get_num_image_tokens(
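patch_size has to be populated because LLaVA-style processors derive the number of image placeholder tokens from the ViT patch grid. A rough illustration (the formula is the standard grid count, ignoring any extra CLS or newline tokens):

def num_patch_tokens(height: int, width: int, patch_size: int) -> int:
    # Number of non-overlapping patches covering the image
    return (height // patch_size) * (width // patch_size)

assert num_patch_tokens(336, 336, 14) == 576  # CLIP ViT-L/14 at 336 px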
@@ -342,6 +342,15 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         **kwargs: object,
     ):
         hf_processor = self.ctx.get_hf_processor()
+
+        # NumPy arrays are considered as Iterable but not Sequence in
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
+        image_processor = hf_processor.image_processor  # type: ignore
+        for attr in ("mean", "std"):
+            val = getattr(image_processor, attr)
+            if isinstance(val, np.ndarray):
+                setattr(image_processor, attr, val.tolist())
+
         return hf_processor
 
     def get_image_processor(self):
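The conversion matters because a NumPy array passes an Iterable check but not a Sequence check, which is what the linked transformers code path tests. A minimal demonstration:

from collections.abc import Iterable, Sequence

import numpy as np

mean = np.array([0.5, 0.5, 0.5])
assert isinstance(mean, Iterable)        # iterable...
assert not isinstance(mean, Sequence)    # ...but not a Sequence

assert isinstance(mean.tolist(), Sequence)  # a plain list satisfies both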
@@ -141,9 +141,9 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     """Equality check between :data:`NestedTensors` objects."""
     if isinstance(a, torch.Tensor):
-        return isinstance(b, torch.Tensor) and bool((a == b).all().item())
+        return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
-        return isinstance(a, torch.Tensor) and bool((b == a).all().item())
+        return isinstance(a, torch.Tensor) and torch.equal(b, a)
 
     if isinstance(a, list):
         return (isinstance(b, list)
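torch.equal compares shape and contents in one call, whereas element-wise == broadcasts and can raise on incompatible shapes; a small sketch of the difference:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0, 3.0])

assert torch.equal(a, a.clone())  # same shape and values -> True
assert not torch.equal(a, b)      # shape mismatch -> False, no error

# By contrast, (a == b) raises a RuntimeError here because the shapes
# cannot be broadcast together.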