[Misc] Simplify max tokens in multimodal registry (#27500)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-10-25 14:56:01 +08:00 (committed by GitHub)
parent b853540388
commit 4c5f632165
4 changed files with 17 additions and 38 deletions


@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
             mm_counts=mm_counts,
         )
         if max_tokens_per_item is not None:
-            return max_tokens_per_item
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
         This is important to take into account when profiling and
         initializing the encoder cache size.
         """
         return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
-
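
The first hunk above changes the early-return path of _get_mm_max_tokens to drop modalities that were not actually requested, matching what the dummy-input fallback already did. A minimal self-contained sketch of that filtering (the function name and example numbers here are hypothetical, not part of the diff):

    from collections.abc import Mapping


    def filter_to_requested(
        max_tokens_per_item: Mapping[str, int],
        mm_counts: Mapping[str, int],
    ) -> dict[str, int]:
        # Same comprehension as the new early return: keep a modality only
        # if the caller asked for at least one item of it.
        return {
            modality: max_tokens
            for modality, max_tokens in max_tokens_per_item.items()
            if mm_counts.get(modality, 0) > 0
        }


    # "video" has a count of 0, so its per-item maximum is dropped.
    print(filter_to_requested({"image": 1024, "video": 4096}, {"image": 2, "video": 0}))
    # -> {'image': 1024}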


@@ -152,6 +152,7 @@ class MultiModalRegistry:
         model_config: "ModelConfig",
         *,
         cache: BaseMultiModalProcessorCache | None = None,
+        profiler_limits: Mapping[str, int] | None = None,
     ) -> Mapping[str, int]:
         """
         Get the maximum number of tokens per data item from each modality based
@@ -164,40 +165,15 @@ class MultiModalRegistry:
         profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
+        profiler_limits = (
+            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
+        )
 
         return profiler.get_mm_max_contiguous_tokens(
             seq_len,
-            {modality: 1 for modality, limit in mm_limits.items() if limit > 0},
+            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
         )
 
-    def get_max_tokens_per_item_by_nonzero_modality(
-        self,
-        model_config: "ModelConfig",
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of tokens per data item from each modality based
-        on underlying model configuration, excluding modalities that user
-        explicitly disabled via `limit_mm_per_prompt`.
-
-        Note:
-            This is currently directly used only in V1 for profiling the memory
-            usage of a model.
-        """
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
-
-        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
-            model_config,
-            cache=cache,
-        )
-
-        return {
-            key: max_tokens_per_mm_item
-            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
-            if mm_limits[key] > 0
-        }
-
     def get_mm_limits_per_prompt(
         self,
         model_config: "ModelConfig",
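
The hunk above folds get_max_tokens_per_item_by_nonzero_modality into get_max_tokens_per_item_by_modality: zero-limit filtering now happens on the profiler_limits mapping, and callers may pass precomputed limits to avoid deriving them twice. A runnable sketch of the new fallback logic (names and numbers are illustrative stand-ins, not vLLM APIs):

    from collections.abc import Callable, Mapping


    def resolve_counts(
        profiler_limits: Mapping[str, int] | None,
        derive_limits: Callable[[], Mapping[str, int]],
    ) -> dict[str, int]:
        # Mirrors the new keyword argument: derive limits from the profiler
        # only when the caller did not supply them, then keep one dummy item
        # per modality whose limit is nonzero.
        limits = derive_limits() if profiler_limits is None else profiler_limits
        return {modality: 1 for modality, limit in limits.items() if limit > 0}


    derived = lambda: {"image": 1, "video": 0}  # stand-in for profiler.get_mm_limits()
    print(resolve_counts(None, derived))          # -> {'image': 1}
    print(resolve_counts({"image": 3}, derived))  # -> {'image': 1}

This fallback is what lets MultiModalBudget (last file below) reuse its already-computed self.mm_limits instead of triggering a second limits computation.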
@@ -369,7 +345,7 @@ class MultiModalRegistry:
         """
         if not model_config.is_encoder_decoder:
             return 0
-        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
         if not max_tokens:
             # TODO - this function assumes encoder-decoder models are
             # multimodal. This will need to change when adding support for more


@@ -264,8 +264,8 @@ def compute_encoder_budget(
         from the input sequence.
     """
     if mm_registry.supports_multimodal_inputs(model_config):
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
-        )
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config
+        )
 
         return compute_mm_encoder_budget(


@@ -42,10 +42,10 @@ class MultiModalBudget:
         self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
 
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
-                model_config, cache=cache
-            )
-        )
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config,
+            cache=cache,
+            profiler_limits=self.mm_limits,
+        )
 
         encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
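
Both V1 call sites above now go through the single consolidated method, which already excludes modalities disabled via limit_mm_per_prompt, so the encoder budget never sees a disabled modality. A toy illustration (pick_budget is a hypothetical stand-in, not vLLM's compute_mm_encoder_budget):

    def pick_budget(max_tokens_by_modality: dict[str, int]) -> int:
        # Size the budget for the most expensive enabled modality; disabled
        # modalities were already filtered out by the registry.
        return max(max_tokens_by_modality.values(), default=0)


    # Hypothetical per-item maxima after filtering (audio was disabled upstream):
    print(pick_budget({"image": 576}))  # -> 576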