From 0c6e40bbaa4707528286a1e7bf17c90c88a1d920 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 21 Aug 2025 16:00:16 +0800 Subject: [PATCH] [Refactor] Simplify code for MM budget (#23310) Signed-off-by: DarkLight1337 --- vllm/v1/core/encoder_cache_manager.py | 58 +++++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 18 +++------ vllm/v1/worker/tpu_model_runner.py | 13 ++---- vllm/v1/worker/utils.py | 40 +++++++++--------- 4 files changed, 59 insertions(+), 70 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index faf5c132f8640..0b9da60c67dee 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from collections.abc import Mapping from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -188,35 +188,47 @@ def compute_encoder_budget( - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ + if mm_registry.supports_multimodal_inputs(model_config): + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) - if not mm_registry.supports_multimodal_inputs(model_config): - return 0, 0 + return compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) - # TODO: handle encoder-decoder models once we support them. - ( - encoder_compute_budget, - encoder_cache_size, - ) = _compute_encoder_budget_multimodal( - model_config, - scheduler_config, - mm_registry, - ) - - return encoder_compute_budget, encoder_cache_size + return compute_text_encoder_budget(scheduler_config) -def _compute_encoder_budget_multimodal( - model_config: "ModelConfig", +def compute_text_encoder_budget( + scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a text-only model. + + Args: + scheduler_config: Scheduler configuration. + + Returns: + - Compute budget for encoder execution, in unit of number of tokens + in the input sequence. + - Space budget for encoder cache size, in unit of number of tokens + in the input sequence. + """ + # Currently text-only encoder-decoder models are not supported + return 0, 0 + + +def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", - mm_registry: MultiModalRegistry, + max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. + max_tokens_by_modality: The maximum number of tokens for each + non-text modality. Returns: - Compute budget for encoder execution, in unit of number of tokens @@ -225,18 +237,14 @@ def _compute_encoder_budget_multimodal( in the input sequence. """ - max_tokens_by_modality_dict = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - - if not max_tokens_by_modality_dict: + if not max_tokens_by_modality: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. 
Encoder cache will " "not be initialized.") return 0, 0 - _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), - key=lambda item: item[1]) + max_tokens_per_mm_item = max(max_tokens_by_modality.values()) if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc86f9826491f..7caa873be4442 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -341,10 +341,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, - ) if self.supports_mm_inputs \ - else None) + ) if self.supports_mm_inputs else None) self.reorder_batch_threshold: Optional[int] = None @@ -669,7 +666,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_budget = self.mm_budget assert mm_budget is not None - dummy_modality, _ = mm_budget.get_modality_with_max_tokens() + dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) @@ -2595,14 +2592,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0f569500cdf6b..2a8d65948d574 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -292,8 +292,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, ) if self.supports_mm_inputs else None) if not self.use_spmd: @@ -1545,14 +1543,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. 
- ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b138f11af1eb1..c7ccd2e254976 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,7 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.registry import MultiModalRegistry from vllm.v1.attention.backends.utils import AttentionMetadataBuilder -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec if TYPE_CHECKING: @@ -27,9 +27,6 @@ class MultiModalBudget: model_config: ModelConfig, scheduler_config: SchedulerConfig, mm_registry: MultiModalRegistry, - *, - max_model_len: int, - max_num_reqs: int, ) -> None: super().__init__() @@ -37,25 +34,25 @@ class MultiModalBudget: self.scheduler_config = scheduler_config self.mm_registry = mm_registry - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=mm_registry, - ) - - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - self.max_model_len = max_model_len - self.max_num_reqs = max_num_reqs + self.max_model_len = model_config.max_model_len + self.max_num_reqs = scheduler_config.max_num_seqs self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) - max_items_per_prompt_by_modality = dict[str, int]() - max_items_per_batch_by_modality = dict[str, int]() - max_tokens_by_modality = mm_registry \ .get_max_tokens_per_item_by_nonzero_modality(model_config) + encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) + + self.encoder_compute_budget = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + max_items_per_prompt_by_modality = dict[str, int]() + max_items_per_batch_by_modality = dict[str, int]() + for modality, max_tokens in max_tokens_by_modality.items(): ( max_items_per_prompt, @@ -69,15 +66,14 @@ class MultiModalBudget: self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality self.max_items_per_batch_by_modality = max_items_per_batch_by_modality - def get_modality_with_max_tokens(self) -> tuple[str, int]: + def get_modality_with_max_tokens(self) -> str: max_tokens_by_modality = self.max_tokens_by_modality - modality, max_tokens = max(max_tokens_by_modality.items(), - key=lambda item: item[1]) + modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) - return modality, max_tokens + return modality def get_encoder_budget(self) -> int: - return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) + return min(self.encoder_compute_budget, self.encoder_cache_size) def get_max_items( self,
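
Below is a minimal, self-contained Python sketch (separate from the patch itself) of the call pattern the model runners use after this refactor. FakeMultiModalBudget is a hypothetical stand-in, not the real vllm.v1.worker.utils.MultiModalBudget, and the token/item numbers are made-up example values; it only mirrors the members this patch touches: get_modality_with_max_tokens() now returning just the modality name, the max_items_per_batch_by_modality mapping, and get_encoder_budget() as min(encoder_compute_budget, encoder_cache_size).

    from dataclasses import dataclass


    @dataclass
    class FakeMultiModalBudget:
        """Hypothetical stand-in mirroring the refactored MultiModalBudget API."""
        max_tokens_by_modality: dict[str, int]
        max_items_per_batch_by_modality: dict[str, int]
        encoder_compute_budget: int = 8192   # example value, not a vLLM default
        encoder_cache_size: int = 8192       # example value, not a vLLM default

        def get_modality_with_max_tokens(self) -> str:
            # After this patch the helper returns only the modality name,
            # so callers no longer unpack a (modality, max_tokens) tuple.
            modality, _ = max(self.max_tokens_by_modality.items(),
                              key=lambda x: x[1])
            return modality

        def get_encoder_budget(self) -> int:
            # Same semantics as before, now reading the renamed
            # encoder_compute_budget attribute.
            return min(self.encoder_compute_budget, self.encoder_cache_size)


    budget = FakeMultiModalBudget(
        max_tokens_by_modality={"image": 2048, "video": 4096},
        max_items_per_batch_by_modality={"image": 4, "video": 2},
    )

    # Mirrors the simplified profiling path in gpu_model_runner.py and
    # tpu_model_runner.py: pick the modality with the largest per-item token
    # count, then look up how many such items fit in a single batch.
    dummy_modality = budget.get_modality_with_max_tokens()
    max_mm_items_per_batch = budget.max_items_per_batch_by_modality[dummy_modality]

    print(dummy_modality)               # video
    print(max_mm_items_per_batch)       # 2
    print(budget.get_encoder_budget())  # 8192

The net effect of the refactor is that callers no longer thread max_model_len and max_num_reqs into MultiModalBudget (they are derived from model_config.max_model_len and scheduler_config.max_num_seqs), and compute_mm_encoder_budget() takes the precomputed max_tokens_by_modality mapping instead of re-querying the registry.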