mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 09:11:18 +08:00
[Bugfix] Correct max tokens for non-contiguous embeds (#21798)
Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com> Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
This commit is contained in:
parent
452b2a3180
commit
0e36abf993
@@ -180,11 +180,14 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
def _get_mm_num_tokens(
|
def _get_mm_num_tokens(
|
||||||
self,
|
self,
|
||||||
mm_inputs: MultiModalInputs,
|
mm_inputs: MultiModalInputs,
|
||||||
|
mm_embeddings_only: bool = True,
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
modality: sum(item.get_num_embeds() for item in placeholders)
|
modality:
|
||||||
|
sum(item.get_num_embeds() if mm_embeddings_only else item.length
|
||||||
|
for item in placeholders)
|
||||||
for modality, placeholders in placeholders_by_modality.items()
|
for modality, placeholders in placeholders_by_modality.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,10 +256,11 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_mm_max_tokens(
|
def _get_mm_max_tokens(
|
||||||
self,
|
self,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
mm_counts: Optional[Mapping[str, int]] = None,
|
mm_counts: Optional[Mapping[str, int]] = None,
|
||||||
|
mm_embeddings_only: bool = True,
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
if mm_counts is None:
|
if mm_counts is None:
|
||||||
mm_counts = self.get_mm_limits()
|
mm_counts = self.get_mm_limits()
|
||||||
@@ -285,4 +289,25 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
return max_tokens_per_item
|
return max_tokens_per_item
|
||||||
|
|
||||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||||
return self._get_mm_num_tokens(mm_inputs)
|
return self._get_mm_num_tokens(mm_inputs,
|
||||||
|
mm_embeddings_only=mm_embeddings_only)
|
||||||
|
|
||||||
|
def get_mm_max_contiguous_tokens(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Optional[Mapping[str, int]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns the maximum length of the multimodal (image placeholders+text)
|
||||||
|
tokens, including any break/text tokens in-between image embeddings.
|
||||||
|
|
||||||
|
<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
|
||||||
|
Returns 9, even when the number of image embeddings is 6.
|
||||||
|
|
||||||
|
This is important to take into account when profiling and
|
||||||
|
initializing the encoder cache size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._get_mm_max_tokens(seq_len,
|
||||||
|
mm_counts,
|
||||||
|
mm_embeddings_only=False)
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class MultiModalRegistry:
|
|||||||
seq_len = model_config.max_model_len
|
seq_len = model_config.max_model_len
|
||||||
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
||||||
|
|
||||||
return profiler.get_mm_max_tokens(
|
return profiler.get_mm_max_contiguous_tokens(
|
||||||
seq_len,
|
seq_len,
|
||||||
{
|
{
|
||||||
modality: 1
|
modality: 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user