# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast

import numpy as np
import numpy.typing as npt
from PIL import Image

import vllm.envs as envs
from vllm.logger import init_logger

from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalInputs, MultiModalKwargsItems,
                     MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)

logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """
    prompt: Union[str, list[int]]
    mm_data: MultiModalDataDict
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
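

# NOTE (illustrative, not part of vLLM): a minimal sketch of the kind of
# `ProcessorInputs` a dummy-inputs builder produces for a single-image
# prompt. The "<image>" placeholder string, the 336x336 resolution and the
# kwargs are assumptions; real values come from the model-specific builder
# defined further below.
def _example_processor_inputs() -> ProcessorInputs:
    return ProcessorInputs(
        prompt="<image>",
        mm_data={"image": [Image.new("RGB", (336, 336), color=255)]},
        hf_processor_mm_kwargs={},
        tokenization_kwargs={"truncation": False},
    )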


class DummyEncoderData(NamedTuple):
    """Dummy data used for profiling."""

    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    prompt_token_ids: list[int]
    multi_modal_data: MultiModalKwargsItems
    multi_modal_placeholders: MultiModalPlaceholderDict


_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(prompt=dummy_text,
                               mm_data=dummy_mm_data,
                               tokenization_kwargs=tokenization_kwargs)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        if num_audios == 0:
            return []
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        if num_images == 0:
            return []
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        if num_videos == 0:
            return []
        video = np.full((num_frames, width, height, 3), 255)
        return [video] * num_videos
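

# NOTE (illustrative, not part of vLLM): a minimal sketch of a concrete dummy
# inputs builder for a hypothetical image-only model. Real builders are
# registered alongside each model's multi-modal processor and typically derive
# the target resolution from `self.info` (e.g. the HF processor config); the
# fixed 336x336 size and the "<image>" placeholder string used here are
# assumptions.
class _ExampleImageDummyInputsBuilder(
        BaseDummyInputsBuilder[BaseProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        # One maximum-size dummy image per allowed image input.
        dummy_images = self._get_dummy_images(
            width=336, height=336, num_images=num_images)
        return {"image": dummy_images}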


class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality:
            sum(item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }
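
    # NOTE (illustrative): `_get_mm_num_tokens` distinguishes the number of
    # multi-modal embedding tokens from the full placeholder span. For the
    # `<im_start> [IMG] ... <im_end>` example shown in the docstring of
    # `get_mm_max_contiguous_tokens` below, it would return {"image": 6} with
    # mm_embeddings_only=True but {"image": 9} with mm_embeddings_only=False,
    # assuming the placeholder range covers the start/end and break tokens.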

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return max_tokens_per_item

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs,
                                       mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ):
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`

        Returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """
        return self._get_mm_max_tokens(seq_len,
                                       mm_counts,
                                       mm_embeddings_only=False)
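

# NOTE (illustrative, not part of vLLM's public API): a minimal sketch of how
# this profiler is typically driven during memory profiling. The caller is
# assumed to already hold a fully constructed multi-modal processor for the
# target model, and `max_num_batched_tokens` plays the role of `seq_len` in
# the methods above.
def _example_profile_run(
    processor: BaseMultiModalProcessor[BaseProcessingInfo],
    max_num_batched_tokens: int,
) -> DummyDecoderData:
    profiler = MultiModalProfiler(processor)

    # Worst-case per-modality token counts, relevant when sizing the
    # encoder cache.
    mm_budget = profiler.get_mm_max_contiguous_tokens(max_num_batched_tokens)
    logger.debug("Worst-case multi-modal token counts: %s", mm_budget)

    # Dummy decoder inputs, padded out to the full profiling sequence length.
    return profiler.get_decoder_dummy_data(max_num_batched_tokens)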