mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:34:57 +08:00
307 lines
10 KiB
Python
307 lines
10 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
|
|
|
import torch.nn as nn
|
|
|
|
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
|
from vllm.inputs import InputProcessingContext
|
|
from vllm.logger import init_logger
|
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
|
cached_tokenizer_from_config)
|
|
from vllm.utils import ClassRegistry
|
|
|
|
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
|
ProcessingCache)
|
|
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
|
|
DummyEncoderData, MultiModalProfiler)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import ModelConfig
|
|
|
|
# Type variable for the model class accepted (and returned unchanged) by the
# decorator that `MultiModalRegistry.register_processor` returns.
N = TypeVar("N", bound=type[nn.Module])

# Processing-info type variables: `_I` is invariant; `_I_co` is the covariant
# variant used where the info object only appears as a return type.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

logger = init_logger(__name__)
|
|
|
|
|
|
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        """Build the processing info object for the given context."""
        ...
|
|
|
|
|
|
class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        """Build the dummy-inputs builder that is parameterized by `info`."""
        ...
|
|
|
|
|
|
class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info, a dummy-inputs builder, and an
    optional processing cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build the processor.

        Args:
            info: Processing info for the model.
            dummy_inputs: Builder of dummy inputs used for profiling.
            cache: Processing cache to use, or `None` to disable caching.
        """
        ...
|
|
|
|
|
|
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The three factories registered for one model class, bundled together so
    that a processor can be built lazily from an `InputProcessingContext`.
    """

    # Builds the processing info from the context.
    info: ProcessingInfoFactory[_I]
    # Builds the multi-modal processor from the info and dummy-inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy-inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build a processor for `ctx`.

        The processing info is created first because both the dummy-inputs
        builder and the processor are constructed from it.

        Args:
            ctx: Context passed to the info factory.
            cache: Processing cache forwarded to the processor factory;
                `None` disables caching.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)
|
|
|
|
|
|
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Model class -> factories used to lazily build its processor;
        # populated through `register_processor`.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # Single cache shared by every processor created with caching enabled.
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)

    def reset_processor_cache(self) -> bool:
        """Reset the multi-modal processing cache."""
        self._processing_cache.reset()

        return True  # Success

    def _get_mm_limits_and_max_tokens_per_item(
        self,
        model_config: "ModelConfig",
    ) -> tuple[Mapping[str, int], Mapping[str, int]]:
        """
        Compute, from a single processor instance, both the per-prompt item
        limit of each modality and the maximum number of tokens per data item
        for every modality whose limit is greater than zero.

        Returns:
            A pair `(mm_limits, max_tokens_per_item)`. Both are empty when the
            model is not multi-modal.
        """
        if not model_config.is_multimodal_model:
            return {}, {}

        # Every `create_processor` call builds a fresh processor (tokenizer
        # lookup, context, info and dummy-inputs builder). Build it once and
        # query both the limits and the token counts from the same profiler,
        # rather than letting each public helper build its own.
        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)

        mm_limits = profiler.get_mm_limits()
        max_tokens_per_item = profiler.get_mm_max_contiguous_tokens(
            model_config.max_model_len,
            # Profile one item of each modality that is not disabled.
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

        return mm_limits, max_tokens_per_item

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Only modalities whose per-prompt limit is greater than zero are
        requested from the profiler. Returns an empty mapping for models
        that are not multi-modal.
        """
        _, max_tokens_per_item = self._get_mm_limits_and_max_tokens_per_item(
            model_config)

        return max_tokens_per_item

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits, max_tokens_per_item = (
            self._get_mm_limits_and_max_tokens_per_item(model_config))

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits[key] > 0
        }

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.

        Each value is the modality's per-prompt item limit multiplied by its
        maximum number of tokens per item.
        """
        mm_limits, max_tokens_per_item = (
            self._get_mm_limits_and_max_tokens_per_item(model_config))

        return {
            key: mm_limits[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        This is the sum of `get_max_tokens_by_modality` over all modalities.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for models that are not multi-modal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns:
            A class decorator that records the factories for the decorated
            model class and returns the class unchanged. Registering a class
            that already has factories logs a warning and overwrites them.
        """

        def wrapper(model_cls: N) -> N:
            # `strict=True` checks only for an exact registration of this
            # class (see `ClassRegistry.contains`).
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class that `model_config` refers to."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Args:
            model_config: Configuration of the model to build a processor for.
            tokenizer: Tokenizer to use. When `None`, it is loaded from
                `model_config`, unless tokenizer initialization is skipped.
            disable_cache: Whether to bypass the shared processing cache.
                When `None`, the model's multi-modal config decides
                (`disable_mm_preprocessor_cache`).

        Raises:
            ValueError: If `model_config` does not describe a multi-modal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            mm_config = model_config.get_multimodal_config()
            disable_cache = mm_config.disable_mm_preprocessor_cache

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Raises:
            AssertionError: If the dummy prompt has fewer than `seq_len`
                tokens.
        """
        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Unlike the decoder variant, a prompt shorter than `seq_len` only
        triggers a one-time warning instead of an error.
        """
        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data
|