[Misc] Simplify max tokens in multimodal registry (#27500)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-25 14:56:01 +08:00 committed by GitHub
parent b853540388
commit 4c5f632165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 38 deletions

View File

@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
mm_counts=mm_counts, mm_counts=mm_counts,
) )
if max_tokens_per_item is not None: if max_tokens_per_item is not None:
return max_tokens_per_item return {
modality: max_tokens
for modality, max_tokens in max_tokens_per_item.items()
if mm_counts.get(modality, 0) > 0
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
This is important to take into account when profiling and This is important to take into account when profiling and
initializing the encoder cache size. initializing the encoder cache size.
""" """
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)

View File

@ -152,6 +152,7 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of tokens per data item from each modality based Get the maximum number of tokens per data item from each modality based
@ -164,40 +165,15 @@ class MultiModalRegistry:
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) profiler_limits = (
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens( return profiler.get_mm_max_contiguous_tokens(
seq_len, seq_len,
{modality: 1 for modality, limit in mm_limits.items() if limit > 0}, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
) )
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.
Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)
return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}
def get_mm_limits_per_prompt( def get_mm_limits_per_prompt(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
@ -369,7 +345,7 @@ class MultiModalRegistry:
""" """
if not model_config.is_encoder_decoder: if not model_config.is_encoder_decoder:
return 0 return 0
max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config) max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
if not max_tokens: if not max_tokens:
# TODO - this function assumes encoder-decoder models are # TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more # multimodal. This will need to change when adding support for more

View File

@ -264,8 +264,8 @@ def compute_encoder_budget(
from the input sequence. from the input sequence.
""" """
if mm_registry.supports_multimodal_inputs(model_config): if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = ( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) model_config
) )
return compute_mm_encoder_budget( return compute_mm_encoder_budget(

View File

@ -42,10 +42,10 @@ class MultiModalBudget:
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_by_modality = ( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
mm_registry.get_max_tokens_per_item_by_nonzero_modality( model_config,
model_config, cache=cache cache=cache,
) profiler_limits=self.mm_limits,
) )
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(