[CI/Build] Improve Tensor Schema test speed by avoiding engine core initialization (#23357)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-09-01 13:52:20 +08:00 committed by GitHub
parent b55713683c
commit ff0e59d83a
7 changed files with 153 additions and 111 deletions

View File

@@ -566,8 +566,7 @@ steps:
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
- pytest -v -s models/multimodal/processing
- label: Multi-Modal Models Test (Standard)
mirror_hardwares: [amdexperimental]

View File

@@ -1,30 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Union
from unittest.mock import patch
import numpy as np
import pytest
import torch.nn as nn
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import InputProcessingContext
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from vllm.utils import is_list_of
from ....conftest import VllmRunner
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides
@@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
return group_mm_kwargs_by_modality(items)
@contextmanager
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config):
with set_default_torch_dtype(model_config.dtype):
model = model_cls(vllm_config=vllm_config)
yield model
del model
cleanup_dist_env_and_memory()
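For orientation, a minimal usage sketch of this context manager (names taken from this diff; the `ModelConfig` construction is abbreviated and its exact kwargs are an assumption):

```python
# Sketch only: construct the model class directly, with no engine core,
# scheduler, or KV cache behind it, then exercise its validators.
model_config = ModelConfig(model_id)  # assumption: full kwargs as in the test
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

with initialize_dummy_model(model_cls, model_config) as model:
    ...  # call the model's _parse_and_validate_*_input methods here
```

Constructing `model_cls(vllm_config=vllm_config)` without a `prefix` argument is presumably also why the granite_speech.py hunk below adds the `prefix: str = ""` default.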
def get_model_id_to_test(
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
filtered_results = []
@@ -155,8 +177,7 @@ def get_model_id_to_test(
@pytest.mark.parametrize(
"model_arch, model_id",
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str,
vllm_runner: type[VllmRunner], monkeypatch):
def test_model_tensor_schema(model_arch: str, model_id: str):
if model_arch in ARCH_TO_SKIP:
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
if model_id in REPO_ID_TO_SKIP:
@@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
hf_overrides=hf_overrides_fn,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
if not any(
hasattr(model_cls, f"_parse_and_validate_{m}_input")
for m in ["image", "video", "audio"]):
inputs_parse_methods = []
for attr_name in dir(model_cls):
attr = getattr(model_cls, attr_name)
if hasattr(attr, "__annotations__"):
return_type = attr.__annotations__.get("return", None)
if return_type is not None and "Input" in str(return_type):
inputs_parse_methods.append(attr_name)
if not any(inputs_parse_methods):
pytest.skip(f"{model_arch} does not support tensor schema validation.")
ctx = InputProcessingContext(
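The discovery loop above replaces a hard-coded `hasattr` check: any method whose return annotation mentions `Input` (e.g. the `Optional[MllamaImagePixelInputs]` annotation added below) becomes a validation entry point. A standalone, runnable sketch with a hypothetical dummy class:

```python
from typing import Optional

class ImagePixelInputs(dict):  # hypothetical stand-in for a schema TypedDict
    pass

class DummyModel:  # hypothetical
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ImagePixelInputs]:
        return None

    def forward(self, x: int) -> int:  # skipped: no "Input" in return type
        return x

found = []
for attr_name in dir(DummyModel):
    attr = getattr(DummyModel, attr_name)
    if hasattr(attr, "__annotations__"):
        return_type = attr.__annotations__.get("return", None)
        if return_type is not None and "Input" in str(return_type):
            found.append(attr_name)

print(found)  # ['_parse_and_validate_image_input']
```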
@@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
processor = factories.build_processor(ctx, cache=None)
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1), monkeypatch.context() as m):
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
# TODO(Isotr0py): Can we avoid initializing engine?
with (
set_default_torch_num_threads(1),
vllm_runner(
model_id,
tokenizer_name=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides_fn,
limit_mm_per_prompt=limit_mm_per_prompt,
enforce_eager=True,
) as vllm_model,
):
model_config = vllm_model.llm.llm_engine.model_config
llm_engine = vllm_model.llm.llm_engine
if hasattr(llm_engine, "processor"):
# v1 processor
mm_registry = llm_engine.processor.mm_registry
else:
# v0 input_preprocessor
mm_registry = llm_engine.input_preprocessor.mm_registry
processor = mm_registry.create_processor(model_config)
def validate_model_input(model, modality: str,
mm_kwargs: MultiModalKwargs):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
valid_func = partial(validate_model_input,
modality=modality,
mm_kwargs=mm_kwargs)
vllm_model.apply_model(valid_func)
with initialize_dummy_model(model_cls, model_config) as model:
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
for method_name in inputs_parse_methods:
print(f"Testing `{method_name}` with modality={modality} "
f"and mm_kwargs{list(mm_kwargs.keys())}")
getattr(model, method_name)(modality=modality, **mm_kwargs)
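The `getattr` dispatch on the last line works because each validator accepts `**kwargs` and pops only the keys it owns, returning `None` when its modality is absent. A standalone sketch with a hypothetical model:

```python
import torch

class DummyModel:  # hypothetical
    def _parse_and_validate_image_input(self, **kwargs: object):
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:  # this modality is absent from the batch
            return None
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type: {type(pixel_values)}")
        return {"type": "pixel_values", "data": pixel_values}

model = DummyModel()
mm_kwargs = {"pixel_values": torch.zeros(2, 3, 336, 336)}
# The extra modality=... keyword is harmlessly absorbed into **kwargs
# by validators that do not declare it.
parsed = getattr(model, "_parse_and_validate_image_input")(
    modality="image", **mm_kwargs)
```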

View File

@@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration(
raise ValueError("Only audio modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

View File

@@ -1371,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
output_tensor[i, :t.size(0)] = t
return output_tensor
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:
# - list[torch.Tensor]:

View File

@@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
inducator_tokens: torch.Tensor
indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`

View File

@@ -3,7 +3,7 @@
""" PyTorch Ovis model."""
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Optional, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
@@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
}
class OvisVideoPatchInputs(TypedDict):
type: Literal["video_patches"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`
"""
patches_per_image: list[int]
"""
List of the total number of patches for each frame in the video.
This is used to restore the first two dimensions of `flat_data`.
"""
def _ovis2_5_field_config():
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
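`OvisVideoPatchInputs` above mirrors `OvisImagePatchInputs` with a distinct `type` tag. A standalone sketch of this tagged-TypedDict pattern (simplified names, not from this diff):

```python
from typing import Literal, TypedDict, Union

import torch

class ImagePatchInputs(TypedDict):
    type: Literal["image_patches"]
    flat_data: torch.Tensor

class VideoPatchInputs(TypedDict):
    type: Literal["video_patches"]
    flat_data: torch.Tensor

def embed(inputs: Union[ImagePatchInputs, VideoPatchInputs]) -> torch.Tensor:
    # Both variants share `flat_data`, so one code path can serve both,
    # while the `type` tag still permits branching (and type narrowing)
    # where the variants differ.
    return inputs["flat_data"]
```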
@@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.get_language_model().make_empty_intermediate_tensors)
def _parse_and_validate_visual_input(
self, is_video,
**kwargs: object) -> Optional[OvisImagePatchInputs]:
if is_video:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
else:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
if pixel_values is None and indicator_tokens is None:
return None
@@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
if pixel_values is None and indicator_tokens is None:
return None
if pixel_values is not None and indicator_tokens is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(indicator_tokens, (torch.Tensor, list)):
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(indicator_tokens)}")
return OvisVideoPatchInputs(
type="video_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
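The nested `flatten_bn(..., concat=True)` calls above collapse per-sample lists of patch tensors into one flat tensor. A minimal sketch of the effect with a hypothetical helper (this is not vLLM's `flatten_bn` implementation):

```python
import torch

def flatten_and_concat(batched: list[list[torch.Tensor]]) -> torch.Tensor:
    # Flatten the (batch, frames) nesting, then concatenate along dim 0.
    flat = [t for sample in batched for t in sample]
    return torch.cat(flat, dim=0)

# Two videos contributing 4 and 8 patches (feature dim is arbitrary here).
pixel_values = [[torch.zeros(4, 1176)], [torch.zeros(8, 1176)]]
print(flatten_and_concat(pixel_values).shape)  # torch.Size([12, 1176])
```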
@@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
torch.cat(vision_embeddings_per_image, dim=0))
return tuple(vision_embeddings)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
embeddings = []
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# NOTE: _parse_and_validate_visual_input has side-effects and pops
# keys from kwargs. We process images first, then videos.
image_input = self._parse_and_validate_visual_input(False, **kwargs)
if image_input:
embeddings.extend(self._process_image_input(image_input))
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values", "indicator_tokens",
"grids") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("video_pixel_values", "video_indicator_tokens",
"video_grids") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
video_input = self._parse_and_validate_visual_input(True, **kwargs)
if video_input:
embeddings.extend(self._process_image_input(video_input))
return modalities
return tuple(embeddings) if embeddings else None
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
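A standalone sketch of why iterating over the kwargs keys preserves modality order: Python dicts keep insertion order, so whichever modality's tensors were passed first is embedded first (hypothetical helper, not from this diff):

```python
def collect_modalities(**kwargs: object) -> dict:
    modalities = {}
    for input_key in kwargs:
        if input_key == "pixel_values" and "images" not in modalities:
            modalities["images"] = kwargs["pixel_values"]
        if input_key == "video_pixel_values" and "videos" not in modalities:
            modalities["videos"] = kwargs["video_pixel_values"]
    return modalities

# Videos were passed first, so they come first in the result.
print(list(collect_modalities(video_pixel_values=1, pixel_values=2)))
# ['videos', 'images']
```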

View File

@@ -1031,8 +1031,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
]
return audio_embeds
def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[dict]:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
if input_image_embeds is None:
return None