From 0e36abf9931baa070609376debb4fb3772f4a3fe Mon Sep 17 00:00:00 2001
From: milesial
Date: Tue, 29 Jul 2025 18:16:25 -0700
Subject: [PATCH] [Bugfix] Correct max tokens for non-contiguous embeds (#21798)

Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
---
 vllm/multimodal/profiling.py | 31 ++++++++++++++++++++++++++++---
 vllm/multimodal/registry.py  |  2 +-
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 7f6fb47a21fa6..d96803b643ff2 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -180,11 +180,14 @@ class MultiModalProfiler(Generic[_I]):
     def _get_mm_num_tokens(
         self,
         mm_inputs: MultiModalInputs,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality:
+            sum(item.get_num_embeds() if mm_embeddings_only else item.length
+                for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
 
@@ -253,10 +256,11 @@ class MultiModalProfiler(Generic[_I]):
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
 
-    def get_mm_max_tokens(
+    def _get_mm_max_tokens(
         self,
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         if mm_counts is None:
             mm_counts = self.get_mm_limits()
@@ -285,4 +289,25 @@ class MultiModalProfiler(Generic[_I]):
             return max_tokens_per_item
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
+        return self._get_mm_num_tokens(mm_inputs,
+                                       mm_embeddings_only=mm_embeddings_only)
+
+    def get_mm_max_contiguous_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ):
+        """
+        Returns the maximum length of the multimodal (image placeholders+text)
+        tokens, including any break/text tokens in-between image embeddings.
+
+        <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
+        Returns 9, even when the number of image embeddings is 6.
+
+        This is important to take into account when profiling and
+        initializing the encoder cache size.
+        """
+
+        return self._get_mm_max_tokens(seq_len,
+                                       mm_counts,
+                                       mm_embeddings_only=False)
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index c44fcacd246c4..bfa391829d290 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -129,7 +129,7 @@ class MultiModalRegistry:
         seq_len = model_config.max_model_len
         mm_limits = self.get_mm_limits_per_prompt(model_config)
 
-        return profiler.get_mm_max_tokens(
+        return profiler.get_mm_max_contiguous_tokens(
             seq_len,
             {
                 modality: 1
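
Note (not part of the patch): a minimal sketch of why the two counts differ.
FakePlaceholder is a hypothetical stand-in for the placeholder items iterated
in _get_mm_num_tokens; it only assumes what the diff above shows, namely that
`length` covers the whole contiguous placeholder span while `get_num_embeds()`
counts only the image-embedding positions.

    # Illustrative sketch, assuming the semantics visible in the diff above.
    from dataclasses import dataclass


    @dataclass
    class FakePlaceholder:
        length: int      # total tokens in the span, including break/text tokens
        num_embeds: int  # tokens that are actual image embeddings

        def get_num_embeds(self) -> int:
            return self.num_embeds


    # <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
    item = FakePlaceholder(length=9, num_embeds=6)

    # mm_embeddings_only=True (old behavior): only the 6 embeddings are counted.
    print(item.get_num_embeds())  # -> 6

    # mm_embeddings_only=False (used by get_mm_max_contiguous_tokens): the full
    # 9-token contiguous span is counted, which is what the encoder cache must fit.
    print(item.length)  # -> 9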