mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:26:00 +08:00
[CI/Build] Improve Tensor Schema tests speed by avoid engine core initialization (#23357)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
b55713683c
commit
ff0e59d83a
@ -566,8 +566,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing
|
||||
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
|
||||
@ -1,30 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Any, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||
UserMessage)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs)
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
|
||||
@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
|
||||
return group_mm_kwargs_by_modality(items)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
with set_current_vllm_config(vllm_config=vllm_config):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
model = model_cls(vllm_config=vllm_config)
|
||||
yield model
|
||||
|
||||
del model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def get_model_id_to_test(
|
||||
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
|
||||
filtered_results = []
|
||||
@ -155,8 +177,7 @@ def get_model_id_to_test(
|
||||
@pytest.mark.parametrize(
|
||||
"model_arch, model_id",
|
||||
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
||||
def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
vllm_runner: type[VllmRunner], monkeypatch):
|
||||
def test_model_tensor_schema(model_arch: str, model_id: str):
|
||||
if model_arch in ARCH_TO_SKIP:
|
||||
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
|
||||
if model_id in REPO_ID_TO_SKIP:
|
||||
@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
hf_overrides=hf_overrides_fn,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
|
||||
if not any(
|
||||
hasattr(model_cls, f"_parse_and_validate_{m}_input")
|
||||
for m in ["image", "video", "audio"]):
|
||||
inputs_parse_methods = []
|
||||
for attr_name in dir(model_cls):
|
||||
attr = getattr(model_cls, attr_name)
|
||||
if hasattr(attr, "__annotations__"):
|
||||
return_type = attr.__annotations__.get("return", None)
|
||||
if return_type is not None and "Input" in str(return_type):
|
||||
inputs_parse_methods.append(attr_name)
|
||||
|
||||
if not any(inputs_parse_methods):
|
||||
pytest.skip(f"{model_arch} does not support tensor schema validation.")
|
||||
|
||||
ctx = InputProcessingContext(
|
||||
@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
modality: 3 if limit is None else limit
|
||||
for modality, limit in supported_mm_limits.items()
|
||||
}
|
||||
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
|
||||
processor = factories.build_processor(ctx, cache=None)
|
||||
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v0(self) -> None:
|
||||
self.cache_config.num_gpu_blocks = 0
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
scheduler_kv_cache_config = get_kv_cache_config(
|
||||
vllm_config,
|
||||
kv_cache_specs[0],
|
||||
10 * GiB_bytes,
|
||||
)
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
||||
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
|
||||
_initialize_kv_caches_v0),
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches",
|
||||
_initialize_kv_caches_v1), monkeypatch.context() as m):
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
if model_info.v0_only:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
# TODO(Isotr0py): Can we avoid initializing engine?
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
model_id,
|
||||
tokenizer_name=model_info.tokenizer,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
max_model_len=model_info.max_model_len,
|
||||
load_format="dummy",
|
||||
hf_overrides=hf_overrides_fn,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
enforce_eager=True,
|
||||
) as vllm_model,
|
||||
):
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
llm_engine = vllm_model.llm.llm_engine
|
||||
|
||||
if hasattr(llm_engine, "processor"):
|
||||
# v1 processor
|
||||
mm_registry = llm_engine.processor.mm_registry
|
||||
else:
|
||||
# v0 input_preprocessor
|
||||
mm_registry = llm_engine.input_preprocessor.mm_registry
|
||||
|
||||
processor = mm_registry.create_processor(model_config)
|
||||
|
||||
def validate_model_input(model, modality: str,
|
||||
mm_kwargs: MultiModalKwargs):
|
||||
method_name = f"_parse_and_validate_{modality}_input"
|
||||
if hasattr(model, method_name):
|
||||
getattr(model, method_name)(**mm_kwargs)
|
||||
|
||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||
model_config, processor):
|
||||
valid_func = partial(validate_model_input,
|
||||
modality=modality,
|
||||
mm_kwargs=mm_kwargs)
|
||||
vllm_model.apply_model(valid_func)
|
||||
with initialize_dummy_model(model_cls, model_config) as model:
|
||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||
model_config, processor):
|
||||
for method_name in inputs_parse_methods:
|
||||
print(f"Testing `{method_name}` with modality={modality} "
|
||||
f"and mm_kwargs{list(mm_kwargs.keys())}")
|
||||
getattr(model, method_name)(modality=modality, **mm_kwargs)
|
||||
|
||||
@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
@ -1371,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
output_tensor[i, :t.size(0)] = t
|
||||
return output_tensor
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
|
||||
# tensor with the same shape will be batched together by
|
||||
# MultiModalKwargs.batch, so pixel_values here can be:
|
||||
# - list[torch.Tensor]:
|
||||
|
||||
@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
inducator_tokens: torch.Tensor
|
||||
indicator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
""" PyTorch Ovis model."""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class OvisVideoPatchInputs(TypedDict):
|
||||
type: Literal["video_patches"]
|
||||
flat_data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
indicator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
"""
|
||||
|
||||
patches_per_image: list[int]
|
||||
"""
|
||||
List of number of total patches for each frame in the video.
|
||||
This is used to restore the first two dimensions of `flat_data`.
|
||||
"""
|
||||
|
||||
|
||||
def _ovis2_5_field_config():
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
grids=MultiModalFieldConfig.batched("image"),
|
||||
@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.get_language_model().make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_visual_input(
|
||||
self, is_video,
|
||||
**kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
if is_video:
|
||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||
grids = kwargs.pop("video_grids", None)
|
||||
else:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||
grids = kwargs.pop("video_grids", None)
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None and indicator_tokens is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(indicator_tokens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of indicator_tokens. "
|
||||
f"Got type: {type(indicator_tokens)}")
|
||||
|
||||
return OvisVideoPatchInputs(
|
||||
type="video_patches",
|
||||
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||
patches_per_image=[
|
||||
x.shape[0] // (self.config.vit_config.hidden_stride**2)
|
||||
for x in flatten_bn(pixel_values)
|
||||
],
|
||||
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
|
||||
concat=True),
|
||||
grids=flatten_bn(flatten_bn(grids), concat=True),
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
|
||||
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
|
||||
) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
indicator_tokens = image_input["indicator_tokens"]
|
||||
@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
torch.cat(vision_embeddings_per_image, dim=0))
|
||||
return tuple(vision_embeddings)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
embeddings = []
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# NOTE: _parse_and_validate_visual_input has side-effects and pops
|
||||
# keys from kwargs. We process images first, then videos.
|
||||
image_input = self._parse_and_validate_visual_input(False, **kwargs)
|
||||
if image_input:
|
||||
embeddings.extend(self._process_image_input(image_input))
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "indicator_tokens",
|
||||
"grids") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("video_pixel_values", "video_indicator_tokens",
|
||||
"video_grids") and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
|
||||
video_input = self._parse_and_validate_visual_input(True, **kwargs)
|
||||
if video_input:
|
||||
embeddings.extend(self._process_image_input(video_input))
|
||||
return modalities
|
||||
|
||||
return tuple(embeddings) if embeddings else None
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
@ -1031,8 +1031,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
]
|
||||
return audio_embeds
|
||||
|
||||
def _parse_and_validate_image_input(self,
|
||||
**kwargs: object) -> Optional[dict]:
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
|
||||
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
||||
if input_image_embeds is None:
|
||||
return None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user