mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 11:02:16 +08:00
[Misc] Simplify max tokens in multimodal registry (#27500)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b853540388
commit
4c5f632165
@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
mm_counts=mm_counts,
|
mm_counts=mm_counts,
|
||||||
)
|
)
|
||||||
if max_tokens_per_item is not None:
|
if max_tokens_per_item is not None:
|
||||||
return max_tokens_per_item
|
return {
|
||||||
|
modality: max_tokens
|
||||||
|
for modality, max_tokens in max_tokens_per_item.items()
|
||||||
|
if mm_counts.get(modality, 0) > 0
|
||||||
|
}
|
||||||
|
|
||||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
||||||
@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
This is important to take into account when profiling and
|
This is important to take into account when profiling and
|
||||||
initializing the encoder cache size.
|
initializing the encoder cache size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ class MultiModalRegistry:
|
|||||||
model_config: "ModelConfig",
|
model_config: "ModelConfig",
|
||||||
*,
|
*,
|
||||||
cache: BaseMultiModalProcessorCache | None = None,
|
cache: BaseMultiModalProcessorCache | None = None,
|
||||||
|
profiler_limits: Mapping[str, int] | None = None,
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
"""
|
"""
|
||||||
Get the maximum number of tokens per data item from each modality based
|
Get the maximum number of tokens per data item from each modality based
|
||||||
@@ -164,40 +165,15 @@ class MultiModalRegistry:
|
|||||||
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
profiler: MultiModalProfiler = MultiModalProfiler(processor)
|
||||||
|
|
||||||
seq_len = model_config.max_model_len
|
seq_len = model_config.max_model_len
|
||||||
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
|
profiler_limits = (
|
||||||
|
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||||
|
)
|
||||||
|
|
||||||
return profiler.get_mm_max_contiguous_tokens(
|
return profiler.get_mm_max_contiguous_tokens(
|
||||||
seq_len,
|
seq_len,
|
||||||
{modality: 1 for modality, limit in mm_limits.items() if limit > 0},
|
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_max_tokens_per_item_by_nonzero_modality(
|
|
||||||
self,
|
|
||||||
model_config: "ModelConfig",
|
|
||||||
*,
|
|
||||||
cache: BaseMultiModalProcessorCache | None = None,
|
|
||||||
) -> Mapping[str, int]:
|
|
||||||
"""
|
|
||||||
Get the maximum number of tokens per data item from each modality based
|
|
||||||
on underlying model configuration, excluding modalities that user
|
|
||||||
explicitly disabled via `limit_mm_per_prompt`.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
This is currently directly used only in V1 for profiling the memory
|
|
||||||
usage of a model.
|
|
||||||
"""
|
|
||||||
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
|
|
||||||
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
|
|
||||||
model_config,
|
|
||||||
cache=cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
key: max_tokens_per_mm_item
|
|
||||||
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
|
|
||||||
if mm_limits[key] > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_mm_limits_per_prompt(
|
def get_mm_limits_per_prompt(
|
||||||
self,
|
self,
|
||||||
model_config: "ModelConfig",
|
model_config: "ModelConfig",
|
||||||
@@ -369,7 +345,7 @@ class MultiModalRegistry:
|
|||||||
"""
|
"""
|
||||||
if not model_config.is_encoder_decoder:
|
if not model_config.is_encoder_decoder:
|
||||||
return 0
|
return 0
|
||||||
max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
|
max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
|
||||||
if not max_tokens:
|
if not max_tokens:
|
||||||
# TODO - this function assumes encoder-decoder models are
|
# TODO - this function assumes encoder-decoder models are
|
||||||
# multimodal. This will need to change when adding support for more
|
# multimodal. This will need to change when adding support for more
|
||||||
|
|||||||
@@ -264,8 +264,8 @@ def compute_encoder_budget(
|
|||||||
from the input sequence.
|
from the input sequence.
|
||||||
"""
|
"""
|
||||||
if mm_registry.supports_multimodal_inputs(model_config):
|
if mm_registry.supports_multimodal_inputs(model_config):
|
||||||
max_tokens_by_modality = (
|
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
|
||||||
mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
|
model_config
|
||||||
)
|
)
|
||||||
|
|
||||||
return compute_mm_encoder_budget(
|
return compute_mm_encoder_budget(
|
||||||
|
|||||||
@@ -42,10 +42,10 @@ class MultiModalBudget:
|
|||||||
|
|
||||||
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
|
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
|
||||||
|
|
||||||
max_tokens_by_modality = (
|
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
|
||||||
mm_registry.get_max_tokens_per_item_by_nonzero_modality(
|
model_config,
|
||||||
model_config, cache=cache
|
cache=cache,
|
||||||
)
|
profiler_limits=self.mm_limits,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
|
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user