mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:45:01 +08:00
[CI/Build] Ensure compatability with Transformers v4.53 (#20541)
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
11c0198615
commit
01cae37713
@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.52.4
|
||||
transformers==4.53.2
|
||||
tokenizers==0.21.1
|
||||
huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
|
||||
@ -800,7 +800,7 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.52.4
|
||||
transformers==4.53.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
|
||||
@ -318,6 +318,7 @@ VLM_TEST_SETTINGS = {
|
||||
num_logprobs=10,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"glm4_1v-video": VLMTestInfo(
|
||||
models=["THUDM/GLM-4.1V-9B-Thinking"],
|
||||
@ -331,8 +332,7 @@ VLM_TEST_SETTINGS = {
|
||||
inputs=custom_inputs.video_with_metadata_glm4_1v(),
|
||||
limit_mm_per_prompt={"video": 1},
|
||||
)],
|
||||
# This is needed to run on machine with 24GB VRAM
|
||||
vllm_runner_kwargs={"gpu_memory_utilization": 0.95},
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"h2ovl": VLMTestInfo(
|
||||
models = [
|
||||
|
||||
@ -159,6 +159,7 @@ def _test_processing_correctness(
|
||||
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||
"mllama": False,
|
||||
"ovis": False,
|
||||
"paligemma": False,
|
||||
"ultravox": False,
|
||||
"whisper": False,
|
||||
}
|
||||
|
||||
@ -31,7 +31,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
# FIXME: Possible memory leak in the previous tests?
|
||||
if model_arch in ("GraniteSpeechForConditionalGeneration",
|
||||
if model_arch in ("Glm4vForConditionalGeneration",
|
||||
"GraniteSpeechForConditionalGeneration",
|
||||
"KimiVLForConditionalGeneration"):
|
||||
pytest.skip("Avoid OOM")
|
||||
|
||||
@ -46,9 +47,14 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
n_group = getattr(text_config, 'n_group', None)
|
||||
num_experts = n_group * 2 if n_group is not None else 2
|
||||
|
||||
# we use three layers for Gemma-3n to check
|
||||
# both normal layer and kv_shared_layer
|
||||
num_hidden_layers = (3 if model_arch
|
||||
== "Gemma3nForConditionalGeneration" else 1)
|
||||
|
||||
text_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"num_experts": num_experts,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": num_experts,
|
||||
@ -56,6 +62,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
"first_k_dense_replace": 0,
|
||||
# To avoid OOM on DeepSeek-V3
|
||||
"n_routed_experts": num_experts,
|
||||
# For Gemma-3n
|
||||
"num_kv_shared_layers": 1,
|
||||
})
|
||||
|
||||
if hasattr(hf_config, "vision_config"):
|
||||
|
||||
@ -5,9 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
@ -137,13 +135,9 @@ class InputProcessingContext(InputContext):
|
||||
/,
|
||||
**kwargs: object,
|
||||
) -> _P:
|
||||
# Transformers 4.53.0 has issue with passing tokenizer to
|
||||
# initialize processor. We disable it for this version.
|
||||
# See: https://github.com/vllm-project/vllm/issues/20224
|
||||
if Version(TRANSFORMERS_VERSION) != Version("4.53.0"):
|
||||
kwargs["tokenizer"] = self.tokenizer
|
||||
return super().get_hf_processor(
|
||||
typ,
|
||||
tokenizer=self.tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -189,10 +189,13 @@ class CohereAttention(nn.Module):
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
layer_has_sliding_window = (
|
||||
getattr(config, "sliding_window_pattern", False)
|
||||
and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
|
||||
getattr(config, "sliding_window_pattern", False) and
|
||||
(layer_idx + 1) % self.config.sliding_window_pattern
|
||||
!= 0) or (getattr(config, "layer_types", False)
|
||||
and config.layer_types[layer_idx] == "sliding_attention")
|
||||
|
||||
self.sliding_window = (interleaved_sliding_window
|
||||
or config.sliding_window
|
||||
if layer_has_sliding_window else None)
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
|
||||
@ -175,12 +175,21 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
|
||||
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||
# New output: (num_images, Pn, Px * Py * C)
|
||||
assert (isinstance(image_patches, list)
|
||||
and len(image_patches) == 1)
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
# image_patches is a list with shape:
|
||||
# (1, num_images, Pn, Px * Py * C)
|
||||
# before Transformers 4.53
|
||||
if isinstance(image_patches, list):
|
||||
assert len(image_patches) == 1
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
# image_patches is a tensor with shape:
|
||||
# (num_images, Pn, Px * Py * C)
|
||||
# after Transformers 4.53
|
||||
elif isinstance(image_patches, torch.Tensor):
|
||||
assert len(image_patches) == len(images)
|
||||
else:
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@ -193,8 +202,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
boa_token_id = vocab["<0x04>"]
|
||||
if prompt_tokens[-1] != boa_token_id:
|
||||
prompt_tokens.append(boa_token_id)
|
||||
|
||||
return prompt_tokens + [boa_token_id]
|
||||
return prompt_tokens
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
|
||||
@ -149,14 +149,17 @@ class Gemma3Attention(nn.Module):
|
||||
# TODO(woosuk): Add reference to the original HF implementation.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.is_sliding = (getattr(
|
||||
config, "interleaved_sliding_window", None) is not None and bool(
|
||||
(layer_idx + 1) % config.sliding_window_pattern))
|
||||
config, "interleaved_sliding_window", None) is not None and (bool(
|
||||
(layer_idx + 1) % config.sliding_window_pattern))) or (
|
||||
getattr(config, "layer_types", None) is not None
|
||||
and config.layer_types[layer_idx] == "sliding_attention")
|
||||
# Initialize the rotary embedding.
|
||||
if self.is_sliding:
|
||||
# Local attention. Override the values in config.json.
|
||||
self.rope_theta = config.rope_local_base_freq
|
||||
self.rope_scaling = {"rope_type": "default"}
|
||||
self.sliding_window = config.interleaved_sliding_window
|
||||
self.sliding_window = (config.interleaved_sliding_window
|
||||
or config.sliding_window)
|
||||
else:
|
||||
# Global attention. Use the values in config.json.
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
@ -30,8 +30,10 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
|
||||
from transformers.models.whisper.modeling_whisper import (ACT2FN,
|
||||
WhisperAttention,
|
||||
WhisperConfig,
|
||||
WhisperEncoder)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -378,14 +380,13 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[
|
||||
config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
|
||||
@ -125,7 +125,7 @@ class PaliGemmaMultiModalProcessor(
|
||||
) -> BatchFeature:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
if not mm_data:
|
||||
prompt_ids = tokenizer.encode(prompt)
|
||||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
return super()._call_hf_processor(
|
||||
|
||||
@ -144,8 +144,16 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
|
||||
) -> Qwen2_5OmniProcessor:
|
||||
if fps is not None:
|
||||
kwargs["fps"] = fps
|
||||
|
||||
# Monkey patch for Transformers v4.53
|
||||
processor_class = Qwen2_5OmniProcessor
|
||||
if processor_class.image_processor_class != "AutoImageProcessor":
|
||||
processor_class.image_processor_class = "AutoImageProcessor"
|
||||
if processor_class.video_processor_class != "AutoVideoProcessor":
|
||||
processor_class.video_processor_class = "AutoVideoProcessor"
|
||||
|
||||
processor = self.ctx.get_hf_processor(
|
||||
Qwen2_5OmniProcessor,
|
||||
processor_class,
|
||||
image_processor=self.get_image_processor(min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
size=size,
|
||||
|
||||
@ -634,7 +634,14 @@ class WhisperProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self,
|
||||
sampling_rate: Optional[int] = None
|
||||
) -> WhisperProcessor:
|
||||
return self.ctx.get_hf_processor(WhisperProcessor)
|
||||
# HACK: Transformers 4.53.0 has issue with whisper tokenizer to
|
||||
# initialize processor. We use a monkeypatch to fix it here.
|
||||
# See: https://github.com/vllm-project/vllm/issues/20224
|
||||
processor_class = WhisperProcessor
|
||||
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
|
||||
if processor_class.tokenizer_class != tokenizer_class:
|
||||
processor_class.tokenizer_class = tokenizer_class
|
||||
return self.ctx.get_hf_processor(processor_class)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": 1}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user