Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[CI/Build] Ensure compatibility with Transformers v4.53 (#20541)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit 01cae37713 (parent 11c0198615)
@@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
 mteb[bm25s]>=1.38.11, <2 # required for mteb test
-transformers==4.52.4
+transformers==4.53.2
 tokenizers==0.21.1
 huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
 schemathesis>=3.39.15 # Required for openai schema test.

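The Transformers pin moves from 4.52.4 to 4.53.2 here and in the compiled lock file (next hunk). When reproducing CI behavior locally, a minimal guard like the following can confirm the environment actually satisfies the new floor (a sketch, not part of the commit; assumes the `packaging` and `transformers` packages are importable):

# Sketch: verify the installed Transformers matches the new pin.
from packaging.version import Version
import transformers

assert Version(transformers.__version__) >= Version("4.53"), (
    f"expected Transformers >= 4.53, found {transformers.__version__}")
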
@@ -800,7 +800,7 @@ tqdm==4.66.6
     # transformers
 tqdm-multiprocess==0.0.11
     # via lm-eval
-transformers==4.52.4
+transformers==4.53.2
     # via
     #   -r requirements/test.in
     #   genai-perf

@@ -318,6 +318,7 @@ VLM_TEST_SETTINGS = {
         num_logprobs=10,
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         auto_cls=AutoModelForImageTextToText,
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     "glm4_1v-video": VLMTestInfo(
         models=["THUDM/GLM-4.1V-9B-Thinking"],
@@ -331,8 +332,7 @@ VLM_TEST_SETTINGS = {
             inputs=custom_inputs.video_with_metadata_glm4_1v(),
             limit_mm_per_prompt={"video": 1},
         )],
-        # This is needed to run on machine with 24GB VRAM
-        vllm_runner_kwargs={"gpu_memory_utilization": 0.95},
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     "h2ovl": VLMTestInfo(
         models = [

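Both GLM-4.1V entries now gate on `large_gpu_mark(min_gb=32)` instead of trying to fit into 24 GB via `gpu_memory_utilization`. A hypothetical stand-in for such a mark, shown only to illustrate the mechanism (this is not vLLM's actual helper), would skip the test unless the visible GPU reports enough memory:

# Hypothetical equivalent of large_gpu_mark: skip unless the first CUDA
# device reports at least `min_gb` GiB of total memory.
import pytest
import torch

def large_gpu_mark(min_gb: int):
    if not torch.cuda.is_available():
        return pytest.mark.skip(reason="requires a CUDA GPU")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"requires at least {min_gb} GiB of GPU memory")
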
@@ -159,6 +159,7 @@ def _test_processing_correctness(
 _ADD_SPECIAL_TOKENS_OVERRIDES = {
     "mllama": False,
     "ovis": False,
+    "paligemma": False,
     "ultravox": False,
     "whisper": False,
 }

@@ -31,7 +31,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info.check_transformers_version(on_fail="skip")
 
     # FIXME: Possible memory leak in the previous tests?
-    if model_arch in ("GraniteSpeechForConditionalGeneration",
+    if model_arch in ("Glm4vForConditionalGeneration",
+                      "GraniteSpeechForConditionalGeneration",
                       "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")
 
@@ -46,9 +47,14 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     n_group = getattr(text_config, 'n_group', None)
     num_experts = n_group * 2 if n_group is not None else 2
 
+    # we use three layers for Gemma-3n to check
+    # both normal layer and kv_shared_layer
+    num_hidden_layers = (3 if model_arch
+                         == "Gemma3nForConditionalGeneration" else 1)
+
     text_config.update({
         "num_layers": 1,
-        "num_hidden_layers": 1,
+        "num_hidden_layers": num_hidden_layers,
         "num_experts": num_experts,
         "num_experts_per_tok": 2,
         "num_local_experts": num_experts,
@@ -56,6 +62,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
         "first_k_dense_replace": 0,
         # To avoid OOM on DeepSeek-V3
         "n_routed_experts": num_experts,
+        # For Gemma-3n
+        "num_kv_shared_layers": 1,
     })
 
     if hasattr(hf_config, "vision_config"):

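The shrunken test config keeps one hidden layer for most architectures but three for Gemma-3n so that, per the comment in the hunk, the stack exercises both a regular layer and a KV-shared layer. Illustratively, and assuming the common reading that the top `num_kv_shared_layers` layers reuse KV caches computed by lower layers, the split under these values works out as:

# Illustration only: which layer indices end up KV-shared under the test
# config above, assuming the shared layers sit at the top of the stack.
num_hidden_layers = 3
num_kv_shared_layers = 1
shared = list(range(num_hidden_layers - num_kv_shared_layers, num_hidden_layers))
print(shared)  # [2] -> layers 0 and 1 are regular, layer 2 is KV-shared
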
@@ -5,9 +5,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
 
 import torch
-from packaging.version import Version
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
-from transformers import __version__ as TRANSFORMERS_VERSION
 from typing_extensions import TypeVar
 
 from vllm.jsontree import JSONTree, json_map_leaves
@@ -137,13 +135,9 @@ class InputProcessingContext(InputContext):
         /,
         **kwargs: object,
     ) -> _P:
-        # Transformers 4.53.0 has issue with passing tokenizer to
-        # initialize processor. We disable it for this version.
-        # See: https://github.com/vllm-project/vllm/issues/20224
-        if Version(TRANSFORMERS_VERSION) != Version("4.53.0"):
-            kwargs["tokenizer"] = self.tokenizer
         return super().get_hf_processor(
             typ,
+            tokenizer=self.tokenizer,
             **kwargs,
         )
 

@@ -189,10 +189,13 @@ class CohereAttention(nn.Module):
 
         layer_idx = extract_layer_index(prefix)
         layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False)
-            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+            getattr(config, "sliding_window_pattern", False) and
+            (layer_idx + 1) % self.config.sliding_window_pattern
+            != 0) or (getattr(config, "layer_types", False)
+                      and config.layer_types[layer_idx] == "sliding_attention")
 
         self.sliding_window = (interleaved_sliding_window
+                               or config.sliding_window
                                if layer_has_sliding_window else None)
 
         self.attn = Attention(self.num_heads,

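For readability, the predicate introduced above is equivalent to the following standalone sketch (`config` stands in for the HF `PretrainedConfig`): older checkpoints describe the interleaving through `sliding_window_pattern`, while Transformers v4.53 configs expose a per-layer `layer_types` list instead.

# Sketch of the layer_has_sliding_window logic above, not the exact vLLM code.
def layer_has_sliding_window(config, layer_idx: int) -> bool:
    pattern = getattr(config, "sliding_window_pattern", None)
    if pattern and (layer_idx + 1) % pattern != 0:
        return True
    layer_types = getattr(config, "layer_types", None)
    return bool(layer_types) and layer_types[layer_idx] == "sliding_attention"
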
@@ -175,12 +175,21 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
 
             # Original output: (1, num_images, Pn, Px * Py * C)
             # New output: (num_images, Pn, Px * Py * C)
-            assert (isinstance(image_patches, list)
-                    and len(image_patches) == 1)
-            assert (isinstance(image_patches[0], torch.Tensor)
-                    and len(image_patches[0]) == len(images))
-
-            processed_outputs["image_patches"] = image_patches[0]
+            # image_patches is a list with shape:
+            # (1, num_images, Pn, Px * Py * C)
+            # before Transformers 4.53
+            if isinstance(image_patches, list):
+                assert len(image_patches) == 1
+                assert (isinstance(image_patches[0], torch.Tensor)
+                        and len(image_patches[0]) == len(images))
+                processed_outputs["image_patches"] = image_patches[0]
+            # image_patches is a tensor with shape:
+            # (num_images, Pn, Px * Py * C)
+            # after Transformers 4.53
+            elif isinstance(image_patches, torch.Tensor):
+                assert len(image_patches) == len(images)
+            else:
+                raise AssertionError("This line should be unreachable.")
 
         return processed_outputs
 
@@ -193,8 +202,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         vocab = tokenizer.get_vocab()
 
         boa_token_id = vocab["<0x04>"]
+        if prompt_tokens[-1] != boa_token_id:
+            prompt_tokens.append(boa_token_id)
 
-        return prompt_tokens + [boa_token_id]
+        return prompt_tokens
 
     def _get_mm_fields_config(
         self,

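The Fuyu branch exists because, per the comments in the new code, the processor's `image_patches` output is a length-1 list before Transformers 4.53 and a plain tensor from 4.53 onwards. A minimal illustration of the two layouts the code has to accept (shapes and values are made up):

# Illustration of the two layouts handled above; numbers are arbitrary.
import torch

num_images, pn, patch_dim = 2, 5, 300
old_style = [torch.zeros(num_images, pn, patch_dim)]  # list of length 1 (pre-4.53)
new_style = torch.zeros(num_images, pn, patch_dim)    # plain tensor (4.53+)

for image_patches in (old_style, new_style):
    per_image = image_patches[0] if isinstance(image_patches, list) else image_patches
    assert len(per_image) == num_images
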
@@ -149,14 +149,17 @@ class Gemma3Attention(nn.Module):
         # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
         self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and bool(
-                (layer_idx + 1) % config.sliding_window_pattern))
+            config, "interleaved_sliding_window", None) is not None and (bool(
+                (layer_idx + 1) % config.sliding_window_pattern))) or (
+                    getattr(config, "layer_types", None) is not None
+                    and config.layer_types[layer_idx] == "sliding_attention")
         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = config.interleaved_sliding_window
+            self.sliding_window = (config.interleaved_sliding_window
+                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta

@@ -30,8 +30,10 @@ import torch
 from torch import nn
 from transformers import BatchFeature, PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.whisper.modeling_whisper import (
-    ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
+from transformers.models.whisper.modeling_whisper import (ACT2FN,
+                                                          WhisperAttention,
+                                                          WhisperConfig,
+                                                          WhisperEncoder)
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -378,14 +380,13 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[
-            config._attn_implementation](
-                embed_dim=self.embed_dim,
-                num_heads=config.encoder_attention_heads,
-                dropout=config.attention_dropout,
-                config=config,
-                layer_idx=layer_idx,
-            )
+        self.self_attn = WhisperAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+            layer_idx=layer_idx,
+        )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]

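The MiniCPM-O encoder previously picked an attention class through the `WHISPER_ATTENTION_CLASSES` lookup table, which is no longer exported in Transformers v4.53 (attention backends are now selected via `config._attn_implementation` inside the module). Constructing `WhisperAttention` directly appears to work across the versions involved; a standalone sketch using default config values:

# Sketch: direct construction of WhisperAttention from a default WhisperConfig.
from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperAttention

config = WhisperConfig()
attn = WhisperAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
    config=config,
    layer_idx=0,
)
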
@@ -125,7 +125,7 @@ class PaliGemmaMultiModalProcessor(
     ) -> BatchFeature:
         tokenizer = self.info.get_tokenizer()
         if not mm_data:
-            prompt_ids = tokenizer.encode(prompt)
+            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
         return super()._call_hf_processor(

|||||||
@ -144,8 +144,16 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
|
|||||||
) -> Qwen2_5OmniProcessor:
|
) -> Qwen2_5OmniProcessor:
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
kwargs["fps"] = fps
|
kwargs["fps"] = fps
|
||||||
|
|
||||||
|
# Monkey patch for Transformers v4.53
|
||||||
|
processor_class = Qwen2_5OmniProcessor
|
||||||
|
if processor_class.image_processor_class != "AutoImageProcessor":
|
||||||
|
processor_class.image_processor_class = "AutoImageProcessor"
|
||||||
|
if processor_class.video_processor_class != "AutoVideoProcessor":
|
||||||
|
processor_class.video_processor_class = "AutoVideoProcessor"
|
||||||
|
|
||||||
processor = self.ctx.get_hf_processor(
|
processor = self.ctx.get_hf_processor(
|
||||||
Qwen2_5OmniProcessor,
|
processor_class,
|
||||||
image_processor=self.get_image_processor(min_pixels=min_pixels,
|
image_processor=self.get_image_processor(min_pixels=min_pixels,
|
||||||
max_pixels=max_pixels,
|
max_pixels=max_pixels,
|
||||||
size=size,
|
size=size,
|
||||||
|
|||||||
@ -634,7 +634,14 @@ class WhisperProcessingInfo(BaseProcessingInfo):
|
|||||||
def get_hf_processor(self,
|
def get_hf_processor(self,
|
||||||
sampling_rate: Optional[int] = None
|
sampling_rate: Optional[int] = None
|
||||||
) -> WhisperProcessor:
|
) -> WhisperProcessor:
|
||||||
return self.ctx.get_hf_processor(WhisperProcessor)
|
# HACK: Transformers 4.53.0 has issue with whisper tokenizer to
|
||||||
|
# initialize processor. We use a monkeypatch to fix it here.
|
||||||
|
# See: https://github.com/vllm-project/vllm/issues/20224
|
||||||
|
processor_class = WhisperProcessor
|
||||||
|
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
|
||||||
|
if processor_class.tokenizer_class != tokenizer_class:
|
||||||
|
processor_class.tokenizer_class = tokenizer_class
|
||||||
|
return self.ctx.get_hf_processor(processor_class)
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
return {"audio": 1}
|
return {"audio": 1}
|
||||||
|
|||||||
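This follows the same pattern as the Qwen2.5-Omni hunk: `ProcessorMixin` subclasses store the slow/fast tokenizer class names in the `tokenizer_class` class attribute, so restoring that pair works around the broken lookup referenced in issue 20224. As a usage-level sanity check (the model id is only an example and downloading it needs network access):

# Sketch: after patching tokenizer_class, WhisperProcessor.from_pretrained
# should resolve a tokenizer again on the affected Transformers version.
from transformers import WhisperProcessor

WhisperProcessor.tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
print(type(processor.tokenizer).__name__)
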