[Bugfix] Fix mm_limits access for merged multi-modal processor (#12252)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-01-21 18:09:39 +08:00 (committed by GitHub)
Parent: f2e9f2a3be
Commit: a94eee4456
2 changed files with 16 additions and 7 deletions

vllm/multimodal/profiling.py

@@ -106,7 +106,7 @@ class MultiModalProfiler(Generic[_I]):
     def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
         return self.processor.dummy_inputs
 
-    def _get_mm_limits(self) -> Mapping[str, int]:
+    def get_mm_limits(self) -> Mapping[str, int]:
         mm_config = self.processing_info.ctx.get_mm_config()
         mm_limit_per_prompt = mm_config.limit_per_prompt

@@ -146,7 +146,7 @@ class MultiModalProfiler(Generic[_I]):
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        mm_counts = self._get_mm_limits()
+        mm_counts = self.get_mm_limits()
 
         info = self.processing_info
         mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
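
In short, the previously private helper _get_mm_limits is now the public get_mm_limits, so callers outside the profiler (notably the registry, below) can query per-prompt limits. A minimal usage sketch, assuming a processor already built by the registry:

from vllm.multimodal.profiling import MultiModalProfiler

# `processor` is assumed to be a merged multi-modal processor
# created elsewhere (see the registry.py changes below).
profiler = MultiModalProfiler(processor)
mm_limits = profiler.get_mm_limits()  # per-modality limits read from mm_config.limit_per_prompt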

vllm/multimodal/registry.py

@@ -17,7 +17,7 @@
 from .image import ImagePlugin
 from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          ProcessingCache)
-from .profiling import BaseDummyInputsBuilder
+from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
 from .utils import cached_get_tokenizer
 from .video import VideoPlugin
@@ -282,13 +282,13 @@ class MultiModalRegistry:
         This is currently directly used only in V1 for profiling the memory
         usage of a model.
         """
-        limits_per_plugin = self._limits_by_model[model_config]
+        mm_limits = self.get_mm_limits_per_prompt(model_config)
 
         return {
             key: max_tokens_per_mm_item
             for key, max_tokens_per_mm_item in
             self.get_max_tokens_per_item_by_modality(model_config).items()
-            if limits_per_plugin[key] > 0
+            if mm_limits[key] > 0
         }
 
     def get_max_tokens_by_modality(
@@ -304,10 +304,10 @@ class MultiModalRegistry:
         Note:
             This should be called after :meth:`init_mm_limits_per_prompt`.
         """
-        limits_per_plugin = self._limits_by_model[model_config]
+        mm_limits = self.get_mm_limits_per_prompt(model_config)
 
         return {
-            key: limits_per_plugin[key] * max_tokens_per_mm_item
+            key: mm_limits[key] * max_tokens_per_mm_item
             for key, max_tokens_per_mm_item in
             self.get_max_tokens_per_item_by_modality(model_config).items()
         }
@@ -371,6 +371,15 @@ class MultiModalRegistry:
         Note:
             This should be called after :meth:`init_mm_limits_per_prompt`.
         """
+        if self.has_processor(model_config):
+            tokenizer = cached_get_tokenizer(
+                model_config.tokenizer,
+                trust_remote_code=model_config.trust_remote_code,
+            )
+            processor = self.create_processor(model_config, tokenizer)
+            profiler = MultiModalProfiler(processor)
+            return profiler.get_mm_limits()
+
         return self._limits_by_model[model_config]
 
     def register_processor(
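
Taken together, get_mm_limits_per_prompt now routes through the merged processor when one is registered, instead of reading _limits_by_model directly (whose entries are only populated by init_mm_limits_per_prompt). A hedged end-to-end sketch of the new path, mirroring the code added above; model_config is assumed to be a vLLM ModelConfig for a multi-modal model, and MULTIMODAL_REGISTRY is assumed as the entry point:

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer

# Build the tokenizer and merged processor exactly as the fixed
# registry method does, then query the profiler for the limits.
tokenizer = cached_get_tokenizer(
    model_config.tokenizer,
    trust_remote_code=model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
mm_limits = profiler.get_mm_limits()  # e.g. {"image": 1}, per mm_config.limit_per_prompt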