Mirror of https://git.datalinker.icu/vllm-project/vllm.git (last synced 2026-01-23 18:44:30 +08:00).
[Bugfix] Correct max tokens for non-contiguous embeds (#21798)
Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com> Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
This commit is contained in:
parent
452b2a3180
commit
0e36abf993
@ -180,11 +180,14 @@ class MultiModalProfiler(Generic[_I]):
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(item.get_num_embeds() for item in placeholders)
|
||||
modality:
|
||||
sum(item.get_num_embeds() if mm_embeddings_only else item.length
|
||||
for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
@ -253,10 +256,11 @@ class MultiModalProfiler(Generic[_I]):
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def get_mm_max_tokens(
|
||||
def _get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Optional[Mapping[str, int]] = None,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
@ -285,4 +289,25 @@ class MultiModalProfiler(Generic[_I]):
|
||||
return max_tokens_per_item
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
return self._get_mm_num_tokens(mm_inputs,
|
||||
mm_embeddings_only=mm_embeddings_only)
|
||||
|
||||
def get_mm_max_contiguous_tokens(
    self,
    seq_len: int,
    mm_counts: Optional[Mapping[str, int]] = None,
):
    """
    Returns the maximum length of the multimodal (image placeholders+text)
    tokens, including any break/text tokens in-between image embeddings.

    ``<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>``
    Returns 9, even when the number of image embeddings is 6.

    This is important to take into account when profiling and
    initializing the encoder cache size.

    Args:
        seq_len: The sequence length used for profiling.
        mm_counts: Optional per-modality item counts; defaults are
            resolved by ``_get_mm_max_tokens`` when ``None``.
    """
    # Delegating with mm_embeddings_only=False counts full placeholder
    # spans (embeddings plus interleaved break/text tokens) rather than
    # embedding positions only.
    return self._get_mm_max_tokens(seq_len,
                                   mm_counts,
                                   mm_embeddings_only=False)
|
||||
|
||||
@ -129,7 +129,7 @@ class MultiModalRegistry:
|
||||
seq_len = model_config.max_model_len
|
||||
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
||||
|
||||
return profiler.get_mm_max_tokens(
|
||||
return profiler.get_mm_max_contiguous_tokens(
|
||||
seq_len,
|
||||
{
|
||||
modality: 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user