[Bugfix] Clean up multi-modal processors (#14417)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-03-07 18:33:38 +08:00; committed by GitHub
parent 12c29a881f
commit 05fb6718f0
12 changed files with 79 additions and 76 deletions

View File

@@ -2405,6 +2405,15 @@ class MultiModalConfig:
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def get_limit_per_prompt(self, modality: str) -> int:
        """
        Get the maximum number of input items allowed per prompt
        for the given modality.

        If not set by the user, this defaults to `1`.
        """
        return self.limit_per_prompt.get(modality, 1)

    # TODO: Add configs to init vision tower or not.
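
For reference, here is a minimal, self-contained sketch of the behaviour this new helper encapsulates, using a hypothetical stand-in class rather than vLLM's actual `MultiModalConfig`: any modality the user did not configure falls back to a limit of 1.

```python
from dataclasses import dataclass, field
from typing import Mapping


@dataclass
class FakeMultiModalConfig:
    """Hypothetical stand-in for vllm.config.MultiModalConfig."""

    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)

    def get_limit_per_prompt(self, modality: str) -> int:
        # Mirrors the helper added in this hunk: unset modalities default to 1.
        return self.limit_per_prompt.get(modality, 1)


cfg = FakeMultiModalConfig(limit_per_prompt={"image": 4})
assert cfg.get_limit_per_prompt("image") == 4
assert cfg.get_limit_per_prompt("video") == 1  # not configured, so defaults to 1
```

The call sites in the hunks below are then migrated from ad-hoc `limit_per_prompt.get(modality, 1)` lookups to this single helper.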

View File

@@ -14,7 +14,6 @@ from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -25,8 +24,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, ProcessingCache,
                                        PromptReplacement, PromptUpdate)
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -42,8 +41,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
logger = init_logger(__name__)
# The image token id may be various
_IMAGE_TOKEN = "<image>"
@@ -216,30 +213,6 @@ class DeepseekVL2DummyInputsBuilder(
class DeepseekVL2MultiModalProcessor(
        BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):

    def __init__(
            self,
            info: DeepseekVL2ProcessingInfo,
            dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
            *,
            cache: Optional[ProcessingCache] = None,
            enable_sanity_checks: bool = True) -> None:
        super().__init__(
            info,
            dummy_inputs,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
        if self.cache is not None and mm_limit["image"] > 2:
            # The processor output depends on the number of images passed,
            # making it incompatible with processing cache which is supposed
            # to be invariant of how many images are passed per prompt
            self.cache = None
            logger.warning_once(
                f"{type(self).__name__} does not support processing cache with "
                "image limit larger than 2.")

    def _call_hf_processor(
        self,
        prompt: str,
@@ -316,6 +289,31 @@ class DeepseekVL2MultiModalProcessor(
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        # The processor logic is different for len(images) <= 2 vs > 2
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 2:
            # This code path corresponds to the cache being disabled
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    DeepseekVL2MultiModalProcessor,
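
To make the caching trade-off concrete, here is a rough, self-contained sketch of the pattern (hypothetical functions operating on plain bytes, not the vLLM classes): per-item caching is used only while the processor's per-image output is independent of the total image count, and is bypassed once that assumption breaks — here, more than 2 images, matching the DeepseekVL2 threshold above.

```python
# Hypothetical toy processor: pretend its per-image output changes once a
# prompt carries more than 2 images (e.g. tiling is turned off for many images).
_per_image_cache: dict[bytes, str] = {}


def _process_images(images: list[bytes], total: int) -> list[str]:
    mode = "tiled" if total <= 2 else "global"
    return [f"{img.hex()}:{mode}" for img in images]


def process(images: list[bytes]) -> list[str]:
    if len(images) > 2:
        # Output depends on the image count here, so per-item caching would
        # return stale "tiled" results; process everything uncached instead.
        return _process_images(images, total=len(images))

    out = []
    for img in images:
        if img not in _per_image_cache:
            _per_image_cache[img] = _process_images([img], total=len(images))[0]
        out.append(_per_image_cache[img])
    return out


print(process([b"\x01", b"\x02"]))           # cached path, "tiled" outputs
print(process([b"\x01", b"\x02", b"\x03"]))  # cache bypassed, "global" outputs
```

The H2OVL processor below applies the same per-call bypass, just with a threshold of a single image.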

View File

@@ -8,21 +8,19 @@
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
from typing import Optional
from typing import Optional, Union
import torch
from PIL import Image
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
@@ -32,8 +30,6 @@ from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
                       InternVLMultiModalProcessor, build_transform,
                       find_closest_aspect_ratio, get_internvl_target_ratios)

logger = init_logger(__name__)


def resolve_h2ovl_min_max_num(
    *,
@@ -465,29 +461,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                               ):

    def __init__(self,
                 info: H2OVLProcessingInfo,
                 dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(
            info,
            dummy_inputs,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
        if self.cache is not None and mm_limit["image"] >= 2:
            # The processor output depends on the number of images passed,
            # making it incompatible with processing cache which is supposed
            # to be invariant of how many images are passed per prompt
            self.cache = None
            logger.warning_once(
                f"{type(self).__name__} does not support processing cache with "
                "multi-image support enabled.")

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
@@ -543,6 +516,31 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        # The processor logic is different for len(images) <= 1 vs > 1
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 1:
            # This code path corresponds to the cache being disabled
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,

View File

@@ -133,7 +133,7 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_videos = mm_config.get_limit_per_prompt("video")
        max_total_frames = self._get_max_video_frames(seq_len)

View File

@@ -206,8 +206,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -

View File

@@ -201,9 +201,9 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_audios = mm_config.limit_per_prompt.get("audio", 1)
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")
        max_audios = mm_config.get_limit_per_prompt("audio")
        # count <image_idx></image_idx> tokens
        # which are not in get_max_image_tokens

View File

@@ -446,8 +446,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")
        # count <image_idx></image_idx> tokens
        # which are not in get_max_image_tokens

View File

@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
    image_token_id = mm_encoder.special_ids.img
    mm_config = ctx.get_mm_config()
    num_images = mm_config.limit_per_prompt.get("image", 1)
    num_images = mm_config.get_limit_per_prompt("image")
    # dummy size
    size = 256

View File

@@ -911,8 +911,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -

View File

@@ -984,10 +984,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        before passing them to :meth:`_get_hf_mm_data`.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_config = self.info.ctx.get_mm_config()
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            limit = mm_config.get_limit_per_prompt(modality)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "

View File

@@ -110,12 +110,10 @@ class MultiModalProfiler(Generic[_I]):
    def get_mm_limits(self) -> Mapping[str, int]:
        mm_config = self.processing_info.ctx.get_mm_config()
        mm_limit_per_prompt = mm_config.limit_per_prompt
        supported_mm_limits = self.processing_info.get_supported_mm_limits()
        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            modality: mm_config.get_limit_per_prompt(modality)
            for modality in supported_mm_limits
        }

View File

@@ -355,7 +355,7 @@ class MultiModalRegistry:
        # TODO: Automatically determine the limits based on budget
        # once more models support multi-image inputs
        limits_per_plugin = {
            key: config_limits_per_plugin.get(key, 1)
            key: multimodal_config.get_limit_per_prompt(key)
            for key in self._plugins
        }
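
Finally, a small self-contained check (hypothetical limits dict, not the real registry plumbing) that the old call-site pattern and the new helper produce identical limits, which is the point of this mechanical migration:

```python
limit_per_prompt = {"image": 4}  # user-configured; "audio" deliberately unset


def get_limit_per_prompt(modality: str) -> int:
    # Mirrors MultiModalConfig.get_limit_per_prompt from the first hunk.
    return limit_per_prompt.get(modality, 1)


plugins = ("image", "audio")
old_style = {key: limit_per_prompt.get(key, 1) for key in plugins}
new_style = {key: get_limit_per_prompt(key) for key in plugins}
assert old_style == new_style == {"image": 4, "audio": 1}
```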