[Misc] Add tensor schema test coverage for multimodal models (#21754)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
parent 337eb23bcc
commit 3dddbf1f25
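
For context, the pattern this change adds test coverage for: multimodal input classes in vLLM subclass TensorSchema and declare their expected tensor shapes with TensorShape annotations, so malformed processor outputs fail fast when the input object is constructed. A minimal sketch of such a class follows; the import path and the example class are illustrative assumptions, not code from this diff.

# Minimal sketch only; the module path below is assumed, and this class is
# hypothetical. Real schemas live next to each model implementation.
from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape  # assumed path


class ExampleImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - h: Height of each image
        - w: Width of each image
    """
    type: Literal["pixel_values"]
    # Symbolic dims ("bn", "h", "w") are checked for consistency across
    # fields when an instance is created.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]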
@@ -581,7 +581,8 @@ steps:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pip freeze | grep -E 'torch'
   - pytest -v -s models/multimodal/processing
-  - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
+  - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
+  - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
   - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work

 - label: Multi-Modal Models Test (Extended) 1
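
To mirror the new CI step locally, a small pytest invocation like the one below should be equivalent, assuming it is run from the repository's tests/ directory so the relative path matches the pipeline command above.

# Rough local equivalent of the new CI command; run from vllm's tests/
# directory so the relative test path resolves.
import pytest

pytest.main([
    "-v", "-s",
    "models/multimodal/test_tensor_schema.py",
    "-m", "core_model",
])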
@@ -775,7 +775,7 @@ class VllmRunner:
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
         seed: Optional[int] = 0,
-        max_model_len: int = 1024,
+        max_model_len: Optional[int] = 1024,
         dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
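
The max_model_len relaxation above lets tests forward a possibly-None value straight through; the new schema test below does exactly that with model_info.max_model_len. A hedged usage sketch, with the model name chosen only for illustration:

# Sketch: vllm_runner is the pytest fixture backed by VllmRunner; passing
# max_model_len=None now type-checks and defers to the model's own maximum
# context length instead of the previous fixed default of 1024.
def test_uses_model_default_context_len(vllm_runner):
    with vllm_runner("facebook/opt-125m", max_model_len=None) as vllm_model:
        vllm_model.generate_greedy(["Hello, my name is"], max_tokens=8)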
tests/models/multimodal/test_tensor_schema.py (new file, 199 lines)
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
from typing import Any
from unittest.mock import patch

import pytest
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ...conftest import VllmRunner
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS

ARCH_TO_SKIP = {
    "MolmoForCausalLM": "incompatible requirements",
    "MiniMaxVL01ForConditionalGeneration": "broken model",
}


def create_batched_mm_kwargs(
    model_config: ModelConfig,
    processor: BaseMultiModalProcessor,
) -> MultiModalKwargs:
    processing_info = processor.info
    dummy_inputs = processor.dummy_inputs
    supported_mm_limits = processing_info.get_supported_mm_limits()
    mm_counts = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }
    processor_inputs = dummy_inputs.get_dummy_processor_inputs(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    mm_kwargs = processor.apply(
        prompt=processor_inputs.prompt,
        mm_data=processor_inputs.mm_data,
        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        tokenization_kwargs=processor_inputs.tokenization_kwargs,
    )["mm_kwargs"]
    mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
    return mm_kwargs


# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig,
                 exist_overrides: dict[str, Any]) -> PretrainedConfig:
    hf_config.update(exist_overrides)
    text_config = hf_config.get_text_config()
    # Ensure at least 2 experts per group
    # since `grouped_topk` assumes top-2
    n_group = getattr(text_config, 'n_group', None)
    num_experts = n_group * 2 if n_group is not None else 2
    # we use three layers for Gemma-3n to check
    # both normal layer and kv_shared_layer
    text_config.update({
        "num_layers": 1,
        "num_hidden_layers": 1,
        "num_experts": num_experts,
        "num_experts_per_tok": 2,
        "num_local_experts": num_experts,
        # Otherwise there will not be any expert layers
        "first_k_dense_replace": 0,
        # To avoid OOM on DeepSeek-V3
        "n_routed_experts": num_experts,
        # For Gemma-3n
        "num_kv_shared_layers": 1,
    })
    if hasattr(hf_config, "vision_config"):
        hf_config.vision_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
        })
    # e.g.: ibm-granite/granite-speech-3.3-2b
    if hasattr(hf_config, "encoder_config"):
        hf_config.encoder_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
        })
    # e.g.: Qwen/Qwen2-Audio-7B-Instruct
    if hasattr(hf_config, "audio_config"):
        hf_config.audio_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
            "encoder_layers": 1,
        })
    return hf_config


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
                             monkeypatch):
    if model_arch in ARCH_TO_SKIP:
        pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")

    model_id = model_info.default

    hf_overrides_fn = partial(hf_overrides,
                              exist_overrides=model_info.hf_overrides)

    model_config = ModelConfig(
        model_id,
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

    if not any(
            hasattr(model_cls, f"_parse_and_validate_{m}_input")
            for m in ["image", "video", "audio"]):
        pytest.skip(f"{model_arch} does not support tensor schema validation.")

    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    # Avoid calling model.forward()
    def _initialize_kv_caches_v0(self) -> None:
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    def _initialize_kv_caches_v1(self, vllm_config):
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_config(
            vllm_config,
            kv_cache_specs[0],
            10 * GiB_bytes,
        )

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                       _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
                       _initialize_kv_caches_v1), monkeypatch.context() as m):
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")

        with (
                set_default_torch_num_threads(1),
                vllm_runner(
                    model_id,
                    tokenizer_name=model_info.tokenizer,
                    tokenizer_mode=model_info.tokenizer_mode,
                    revision=model_info.revision,
                    trust_remote_code=model_info.trust_remote_code,
                    max_model_len=model_info.max_model_len,
                    load_format="dummy",
                    hf_overrides=hf_overrides_fn,
                    limit_mm_per_prompt=limit_mm_per_prompt,
                    enforce_eager=True,
                ) as vllm_model,
        ):
            model_config = vllm_model.llm.llm_engine.model_config
            llm_engine = vllm_model.llm.llm_engine

            if hasattr(llm_engine, "processor"):
                # v1 processor
                mm_registry = llm_engine.processor.mm_registry
            else:
                # v0 input_preprocessor
                mm_registry = llm_engine.input_preprocessor.mm_registry

            processor = mm_registry.create_processor(model_config)
            mm_kwargs = create_batched_mm_kwargs(model_config, processor)

            def validate_model_input(model):
                for modality in ("audio", "image", "video"):
                    method_name = f"_parse_and_validate_{modality}_input"
                    if hasattr(model, method_name):
                        getattr(model, method_name)(**mm_kwargs)

            vllm_model.apply_model(validate_model_input)
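
The validate_model_input helper above ends up calling each model's _parse_and_validate_<modality>_input hook with batched dummy kwargs; constructing the TensorSchema subclass inside that hook is what triggers the shape checks. A hedged sketch of the typical hook shape, reusing the hypothetical ExampleImagePixelInputs class sketched near the top of this page:

# Hypothetical hook, shown only to illustrate what the test exercises;
# real models define their own schema classes and extra preprocessing.
from typing import Optional


class ExampleModel:

    def _parse_and_validate_image_input(
            self, **kwargs) -> Optional["ExampleImagePixelInputs"]:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        # Constructing the schema object runs the TensorShape checks
        # declared on its Annotated fields.
        return ExampleImagePixelInputs(type="pixel_values",
                                       data=pixel_values)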
@@ -383,6 +383,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
                                                          is_available_online=False), # noqa: E501
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
+                                      trust_remote_code=True,
                                       extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
                                       max_transformers_version="4.48", # noqa: E501
                                       transformers_version_reason="HF model is not compatible."), # noqa: E501
@@ -432,6 +433,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
                                                trust_remote_code=True),
+    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
+                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
+                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
     "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
                                                          extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
     "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
@@ -439,9 +443,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                         max_transformers_version="4.48",
                                         transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
                                         extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
-    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
-                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
-                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501
@@ -51,13 +51,14 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
+        - p: Number of patches
         - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """
    type: Literal["pixel_values"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", 3, "h", "w")]
+                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]

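
dynamic_dims={"p"} presumably marks the patch dimension as free to differ between items when data arrives as a list of per-image tensors rather than one stacked tensor; the remaining dimensions still have to agree. A small illustration of the intended shapes (semantics inferred, not verified against the TensorShape implementation):

# Assumed semantics: each list item matches ("p", 3, "h", "w"), with "p"
# allowed to vary per image because it is declared dynamic.
import torch

item_a = torch.zeros(5, 3, 384, 384)   # 5 patches for the first image
item_b = torch.zeros(12, 3, 384, 384)  # 12 patches for the second image
data = [item_a, item_b]  # acceptable as list input; a fixed "p" would reject this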
@@ -104,13 +104,16 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
+        - b: Batch size
         - np: Number of patches
-        - cps: Number of channels * patch_size * patch_size
+        - c: Number of channels
+        - ps: Patch size
         - ni: Number of images
         - g: Grid dimensions (3 for t, h, w)
     """
     type: Literal["pixel_values"]
-    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    pixel_values: Annotated[torch.Tensor,
+                            TensorShape("b", "np", 3, "ps", "ps")]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -134,14 +137,16 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
+        - b: Batch size
         - np: Number of patches
-        - ctps: Number of channels * temporal_patch_size * patch_size *
-            patch_size
-        - nv: Number of videos
+        - c: Number of channels
+        - ps: Patch size
+        - ni: Number of images
         - g: Grid dimensions (3 for t, h, w)
     """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
+    pixel_values_videos: Annotated[torch.Tensor,
+                                   TensorShape("b", "np", 3, "ps", "ps")]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

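
The two Keye schemas above replace the flattened HF layout ("np", "cps") with an explicit ("b", "np", 3, "ps", "ps") layout; assuming the usual channel-major flattening, the flattened width relates to the new trailing dims by cps = 3 * ps * ps. A small worked example with an assumed patch size of 14:

# Illustration of how the old flattened patch layout maps onto the new
# explicit one; the patch size and patch count are arbitrary examples.
import torch

ps = 14
flat = torch.zeros(1024, 3 * ps * ps)     # old layout: ("np", "cps")
explicit = flat.view(1, 1024, 3, ps, ps)  # new layout: ("b", "np", 3, "ps", "ps")
assert explicit.shape == (1, 1024, 3, ps, ps)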
@@ -256,7 +256,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
     def __call__(
         self,
         *,
-        prompt: str,
+        text: str,
         images: list[Image.Image],
         inference_mode: bool = True,
         **kwargs,
@@ -264,7 +264,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         """

         Args:
-            prompt (str): the formatted prompt;
+            text (str): the formatted prompt;
             images (list[ImageType]): the list of images;
             inference_mode (bool): if True, then remove the last eos token;
             **kwargs:
@@ -278,7 +278,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         """

         prepare = self.process_one(
-            prompt=prompt,
+            prompt=text,
             images=images,
             inference_mode=inference_mode,
         )
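
The prompt-to-text rename in the hunks above aligns the test-only DeepseekVLV2Processor with the keyword convention most HF processors use, so generic calling code can pass text= uniformly. A hedged usage sketch; the processor instance and image are assumed to have been created elsewhere, and only the keyword is the point here.

# Hypothetical call after the rename.
outputs = processor(
    text="<image>\nDescribe the image.",
    images=[image],
    inference_mode=True,
)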