[Misc] Add tensor schema test coverage for multimodal models (#21754)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-08-03 15:52:14 +08:00 committed by GitHub
parent 337eb23bcc
commit 3dddbf1f25
7 changed files with 222 additions and 15 deletions

View File

@@ -581,7 +581,8 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal/processing
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
- pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
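The tensor schema test gets its own pytest invocation because it needs mp_method="spawn", which conflicts with running it in-process alongside the rest of the multimodal suite. A rough local equivalent of these CI steps is sketched below; it assumes a vLLM checkout with the tests/ directory as the current working directory.

```python
# A rough local equivalent of the CI steps above; assumes a vLLM checkout with
# the tests/ directory as the current working directory and pytest installed.
# (In practice run these as separate processes, as the CI does.)
import pytest

# Everything except whisper and the new tensor schema test runs in one selection.
pytest.main([
    "-v", "-s",
    "--ignore=models/multimodal/generation/test_whisper.py",
    "--ignore=models/multimodal/test_tensor_schema.py",
    "models/multimodal", "-m", "core_model",
])

# The tensor schema test runs on its own because it needs mp_method="spawn".
pytest.main([
    "-v", "-s",
    "models/multimodal/test_tensor_schema.py", "-m", "core_model",
])
```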

View File

@@ -775,7 +775,7 @@ class VllmRunner:
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
seed: Optional[int] = 0,
max_model_len: int = 1024,
max_model_len: Optional[int] = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
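Relaxing max_model_len to Optional[int] lets callers such as the new tensor schema test pass None when the model's own context length should apply. A minimal sketch of that calling convention, assuming the runner simply omits the argument when it is None; build_runner_kwargs is illustrative and not part of the test utilities.

```python
# A minimal sketch, assuming the runner drops the kwarg when it is None so the
# engine falls back to the model's own maximum context length.
# `build_runner_kwargs` is illustrative, not part of the vLLM test utilities.
from typing import Optional


def build_runner_kwargs(max_model_len: Optional[int] = 1024) -> dict:
    kwargs = {"dtype": "auto", "enforce_eager": True}
    if max_model_len is not None:
        kwargs["max_model_len"] = max_model_len
    return kwargs


print(build_runner_kwargs(None))   # no max_model_len -> model default applies
print(build_runner_kwargs(2048))   # context length pinned explicitly
```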

View File

@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
from typing import Any
from unittest.mock import patch
import pytest
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ...conftest import VllmRunner
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS

ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"MiniMaxVL01ForConditionalGeneration": "broken model",
}


def create_batched_mm_kwargs(
model_config: ModelConfig,
processor: BaseMultiModalProcessor,
) -> MultiModalKwargs:
processing_info = processor.info
dummy_inputs = processor.dummy_inputs
supported_mm_limits = processing_info.get_supported_mm_limits()
mm_counts = {
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
processor_inputs = dummy_inputs.get_dummy_processor_inputs(
seq_len=model_config.max_model_len,
mm_counts=mm_counts,
)
mm_kwargs = processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)["mm_kwargs"]
mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
return mm_kwargs


# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig,
exist_overrides: dict[str, Any]) -> PretrainedConfig:
hf_config.update(exist_overrides)
text_config = hf_config.get_text_config()
# Ensure at least 2 experts per group,
# since `grouped_topk` assumes top-2
n_group = getattr(text_config, 'n_group', None)
num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
text_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: ibm-granite/granite-speech-3.3-2b
if hasattr(hf_config, "encoder_config"):
hf_config.encoder_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
if hasattr(hf_config, "audio_config"):
hf_config.audio_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
})
return hf_config


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
monkeypatch):
if model_arch in ARCH_TO_SKIP:
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_id = model_info.default
hf_overrides_fn = partial(hf_overrides,
exist_overrides=model_info.hf_overrides)
model_config = ModelConfig(
model_id,
tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
if not any(
hasattr(model_cls, f"_parse_and_validate_{m}_input")
for m in ["image", "video", "audio"]):
pytest.skip(f"{model_arch} does not support tensor schema validation.")
ctx = InputProcessingContext(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
limit_mm_per_prompt = {
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1), monkeypatch.context() as m):
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
with (
set_default_torch_num_threads(1),
vllm_runner(
model_id,
tokenizer_name=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides_fn,
limit_mm_per_prompt=limit_mm_per_prompt,
enforce_eager=True,
) as vllm_model,
):
model_config = vllm_model.llm.llm_engine.model_config
llm_engine = vllm_model.llm.llm_engine
if hasattr(llm_engine, "processor"):
# v1 processor
mm_registry = llm_engine.processor.mm_registry
else:
# v0 input_preprocessor
mm_registry = llm_engine.input_preprocessor.mm_registry
processor = mm_registry.create_processor(model_config)
mm_kwargs = create_batched_mm_kwargs(model_config, processor)
def validate_model_input(model):
for modality in ("audio", "image", "video"):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
vllm_model.apply_model(validate_model_input)
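The test batches dummy multimodal inputs through each model's processor and then calls every `_parse_and_validate_{modality}_input` method it finds, so schema violations surface without running a forward pass. A self-contained toy version of that final step is sketched below; ToyModel and the hard-coded kwargs are purely illustrative, whereas real models receive the batched MultiModalKwargs built by create_batched_mm_kwargs.

```python
# Illustrative only: a toy model using the same naming convention the test
# relies on. Real models receive batched MultiModalKwargs from the processor
# instead of the hard-coded tensor used here.
import torch


class ToyModel:

    def _parse_and_validate_image_input(self, *, pixel_values: torch.Tensor, **_):
        # A real model would check dtypes and symbolic dims via its TensorSchema.
        assert pixel_values.dim() == 5, "expected (bn, p, c, h, w)"
        return pixel_values


def validate_model_input(model) -> None:
    mm_kwargs = {"pixel_values": torch.zeros(1, 2, 3, 224, 224)}
    for modality in ("audio", "image", "video"):
        method_name = f"_parse_and_validate_{modality}_input"
        if hasattr(model, method_name):
            getattr(model, method_name)(**mm_kwargs)


validate_model_input(ToyModel())
```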

View File

@@ -383,6 +383,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
@@ -432,6 +433,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
trust_remote_code=True),
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
@@ -439,9 +443,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501

View File

@@ -51,13 +51,14 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w")]
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
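The updated schema adds a patch dimension "p" and marks it dynamic, so `data` may arrive as a list of per-image tensors whose patch counts differ while the remaining dimensions stay fixed. A hedged sketch of the invariant this encodes follows; the helper patch_counts is illustrative and not vLLM's TensorSchema implementation.

```python
# A hedged sketch of the invariant behind dynamic_dims={"p"}: per-image patch
# counts may differ, so the field can be a list of (p, 3, h, w) tensors that
# agree on every dimension except "p". `patch_counts` is illustrative only.
import torch


def patch_counts(data: list[torch.Tensor]) -> list[int]:
    ref = data[0].shape[-3:]  # (c, h, w) must match across images
    assert all(t.shape[-3:] == ref for t in data), "c/h/w must be consistent"
    return [t.shape[0] for t in data]  # "p" is free to vary


images = [torch.zeros(4, 3, 384, 384), torch.zeros(7, 3, 384, 384)]
print(patch_counts(images))  # [4, 7]
```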

View File

@@ -104,13 +104,16 @@ def smart_resize(
class KeyeImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- cps: Number of channels * patch_size * patch_size
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
pixel_values: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
@@ -134,14 +137,16 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
class KeyeVideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- ctps: Number of channels * temporal_patch_size * patch_size *
patch_size
- nv: Number of videos
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
pixel_values_videos: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
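The Keye schemas move from flattened 2-D layouts to structured 5-D ones. The relationship between the two layouts is a plain reshape, sketched below with arbitrary sizes; the real patch size and channel count come from the model's vision config.

```python
# A hedged sketch of the layout change: the old schema flattened each patch to
# (np, c * ps * ps); the new one keeps (b, np, c, ps, ps). Sizes here are
# arbitrary and only illustrate the reshape relationship.
import torch

b, num_patches, c, ps = 1, 16, 3, 14
structured = torch.randn(b, num_patches, c, ps, ps)            # new schema layout
flattened = structured.reshape(b * num_patches, c * ps * ps)   # old schema layout
assert flattened.shape == (16, 3 * 14 * 14)
```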

View File

@@ -256,7 +256,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
def __call__(
self,
*,
prompt: str,
text: str,
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
@@ -264,7 +264,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
"""
Args:
prompt (str): the formatted prompt;
text (str): the formatted prompt;
images (list[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
@@ -278,7 +278,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
"""
prepare = self.process_one(
prompt=prompt,
prompt=text,
images=images,
inference_mode=inference_mode,
)
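Renaming the keyword from `prompt` to `text` aligns the wrapper with the HF ProcessorMixin calling convention, while the internal process_one call keeps its `prompt=` parameter. A stand-in processor illustrating the new call signature is below; FakeProcessor is purely illustrative and not the real DeepseekVLV2Processor.

```python
# Illustrative only: a stand-in showing the keyword the wrapper now expects.
# The formatted prompt still flows into process_one(prompt=...); only the
# public parameter name changed to match the HF `text=` convention.
class FakeProcessor:

    def __call__(self, *, text: str, images: list, inference_mode: bool = True,
                 **kwargs):
        return {"prompt": text, "num_images": len(images)}


print(FakeProcessor()(text="<image>\nDescribe the image.", images=[object()]))
```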