From 49252cf59e70b1e1a8bae21da929f6d51e9acce4 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 15 Aug 2025 04:41:38 -0700
Subject: [PATCH] [MM] Allow skipping memory profiling for multimodal models. (#22950)

Signed-off-by: Roger Wang
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm/config/__init__.py            |  17 ++++-
 vllm/engine/arg_utils.py           |   4 ++
 vllm/v1/worker/gpu_model_runner.py |  84 ++++++++++++-----------
 vllm/v1/worker/tpu_model_runner.py | 104 +++++++++++++++--------------
 4 files changed, 120 insertions(+), 89 deletions(-)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index b4ea15ef5a0f8..a2e93c344b3f3 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -388,6 +388,10 @@ class ModelConfig:
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
    --chat-template-content-format=string. Defaults to False."""
+    skip_mm_profiling: bool = False
+    """When enabled, skips multimodal memory profiling and profiles only the
+    language backbone model during engine initialization.
+    """
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
@@ -837,7 +841,8 @@ class ModelConfig:
                 media_io_kwargs=self.media_io_kwargs,
                 mm_processor_kwargs=self.mm_processor_kwargs,
                 mm_processor_cache_gb=self.mm_processor_cache_gb,
-                interleave_mm_strings=self.interleave_mm_strings)
+                interleave_mm_strings=self.interleave_mm_strings,
+                skip_mm_profiling=self.skip_mm_profiling)

         return None

@@ -2511,6 +2516,16 @@ class MultiModalConfig:
     Enable fully interleaved support for multimodal prompts.
     """

+    skip_mm_profiling: bool = False
+    """
+    When enabled, skips multimodal memory profiling and profiles only the
+    language backbone model during engine initialization.
+
+    This reduces engine startup time but shifts the responsibility to users for
+    estimating the peak memory usage of the activations of the multimodal
+    encoder and embedding cache.
+ """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dd1072da08447..31de2ede7a380 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -350,6 +350,7 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -716,6 +717,8 @@ class EngineArgs: multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) + multimodal_group.add_argument("--skip-mm-profiling", + **multimodal_kwargs["skip_mm_profiling"]) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -918,6 +921,7 @@ class EngineArgs: limit_mm_per_prompt=self.limit_mm_per_prompt, interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, + skip_mm_profiling=self.skip_mm_profiling, use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8fb9641844fb5..703092ca9feeb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2479,50 +2479,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - # Run multimodal encoder. 
- dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 46262284e3333..f7e68edba3a13 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. 
- start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in in %.2f [secs].", - end - start) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in %.2f [secs].", + end - start) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. self._dummy_run(num_tokens, self.num_reqs_max_model_len,
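
Usage note (editorial addition, not part of the patch): a minimal sketch of how
the new option could be enabled once this change lands. It assumes the usual
vLLM behavior that extra keyword arguments passed to LLM(...) are forwarded to
EngineArgs; the model name below is only an illustrative placeholder and is not
taken from this PR.

    # Hedged sketch: enable the new flag from the offline Python API.
    # Assumes LLM(**kwargs) forwards fields to EngineArgs (standard vLLM
    # behavior); the model name is a placeholder, not from this patch.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-VL-3B-Instruct",  # placeholder multimodal model
        skip_mm_profiling=True,  # skip MM encoder / encoder-cache profiling
    )

The server-side form should be the CLI flag registered in arg_utils.py, i.e.
"vllm serve <model> --skip-mm-profiling". When enabled, startup is expected to
log "Skipping memory profiling for multimodal encoder and encoder cache."
instead of running the dummy encoder batch during profile_run().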