[MM] Allow skipping memory profiling for multimodal models. (#22950)

Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Roger Wang 2025-08-15 04:41:38 -07:00 committed by GitHub
parent 3e6dd40016
commit 49252cf59e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 120 additions and 89 deletions

View File

@ -388,6 +388,10 @@ class ModelConfig:
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string. Defaults to False."""
skip_mm_profiling: bool = False
"""When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
@ -837,7 +841,8 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
interleave_mm_strings=self.interleave_mm_strings)
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
return None
@ -2511,6 +2516,16 @@ class MultiModalConfig:
Enable fully interleaved support for multimodal prompts.
"""
skip_mm_profiling: bool = False
"""
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -350,6 +350,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@ -716,6 +717,8 @@ class EngineArgs:
multimodal_group.add_argument(
"--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"])
multimodal_group.add_argument("--skip-mm-profiling",
**multimodal_kwargs["skip_mm_profiling"])
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
@ -918,6 +921,7 @@ class EngineArgs:
limit_mm_per_prompt=self.limit_mm_per_prompt,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,

View File

@ -2479,50 +2479,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
if self.model_config.multimodal_config.skip_mm_profiling:
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the maximum "
"feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
"Skipping memory profiling for multimodal encoder and "
"encoder cache.")
else:
mm_budget = self.mm_budget
assert mm_budget is not None
# Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality,
max_mm_items_per_batch,
)
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the "
"maximum feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality,
max_mm_items_per_batch,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs))
# Run multimodal encoder.
dummy_encoder_outputs = \
self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs))
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states \

View File

@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
if self.model_config.multimodal_config.skip_mm_profiling:
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the maximum "
"feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
"Skipping memory profiling for multimodal encoder and "
"encoder cache.")
else:
mm_budget = self.mm_budget
assert mm_budget is not None
# Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality,
max_mm_items_per_batch,
)
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
(
dummy_modality,
max_tokens,
) = mm_budget.get_modality_with_max_tokens()
(
max_mm_items_per_prompt,
max_mm_items_per_batch,
) = mm_budget.get_max_items(dummy_modality, max_tokens)
# Run multimodal encoder.
# Isolate encoder graph from post-processing to minimize
# impact of recompilation until it's fixed.
start = time.perf_counter()
xm.mark_step()
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal Encoder profiling finished in in %.2f [secs].",
end - start)
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the "
"maximum feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality,
max_mm_items_per_batch,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs))
# Run multimodal encoder.
# Isolate encoder graph from post-processing to minimize
# impact of recompilation until it's fixed.
start = time.perf_counter()
xm.mark_step()
dummy_encoder_outputs = \
self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal Encoder profiling finished in %.2f [secs].",
end - start)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
self._dummy_run(num_tokens, self.num_reqs_max_model_len,