[Bugfix] Correct max tokens for non-contiguous embeds (#21798)

Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
milesial 2025-07-29 18:16:25 -07:00 committed by GitHub
parent 452b2a3180
commit 0e36abf993
2 changed files with 29 additions and 4 deletions


@@ -180,11 +180,14 @@ class MultiModalProfiler(Generic[_I]):
     def _get_mm_num_tokens(
         self,
         mm_inputs: MultiModalInputs,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality:
+            sum(item.get_num_embeds() if mm_embeddings_only else item.length
+                for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
 
@@ -253,10 +256,11 @@ class MultiModalProfiler(Generic[_I]):
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
 
-    def get_mm_max_tokens(
+    def _get_mm_max_tokens(
         self,
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         if mm_counts is None:
             mm_counts = self.get_mm_limits()
@@ -285,4 +289,25 @@ class MultiModalProfiler(Generic[_I]):
             return max_tokens_per_item
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
+        return self._get_mm_num_tokens(mm_inputs,
+                                       mm_embeddings_only=mm_embeddings_only)
+
+    def get_mm_max_contiguous_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ):
+        """
+        Returns the maximum length of the multimodal (image placeholders+text)
+        tokens, including any break/text tokens in-between image embeddings.
+
+        <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
+        Returns 9, even when the number of image embeddings is 6.
+
+        This is important to take into account when profiling and
+        initializing the encoder cache size.
+        """
+        return self._get_mm_max_tokens(seq_len,
+                                       mm_counts,
+                                       mm_embeddings_only=False)


@@ -129,7 +129,7 @@ class MultiModalRegistry:
         seq_len = model_config.max_model_len
         mm_limits = self.get_mm_limits_per_prompt(model_config)
 
-        return profiler.get_mm_max_tokens(
+        return profiler.get_mm_max_contiguous_tokens(
             seq_len,
             {
                 modality: 1