[BugFix][MM]support VLLM_RANDOMIZE_DP_DUMMY_INPUTS (#30472)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Xingyu Liu 2025-12-11 13:00:15 -08:00 committed by GitHub
parent cf3eacfe58
commit 90d6cf921f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import gc
import itertools
import time
@ -3892,19 +3893,21 @@ class GPUModelRunner(
return {}
@contextmanager
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
def maybe_randomize_inputs(
self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None
):
"""
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
if not randomize_inputs:
yield
else:
import functools
elif input_ids is not None:
@functools.cache
def rand_input_ids() -> torch.Tensor:
@ -3912,13 +3915,27 @@ class GPUModelRunner(
self.input_ids.gpu,
low=0,
high=self.model_config.get_vocab_size(),
dtype=input_ids.dtype,
)
logger.debug_once("Randomizing dummy data for DP Rank")
logger.debug_once("Randomizing dummy input_ids for DP Rank")
input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True)
yield
input_ids.fill_(0)
else:
@functools.cache
def rand_inputs_embeds() -> torch.Tensor:
return torch.randn_like(
self.inputs_embeds.gpu,
)
assert inputs_embeds is not None
logger.debug_once("Randomizing dummy inputs_embeds for DP Rank")
inputs_embeds.copy_(
rand_inputs_embeds()[: inputs_embeds.size(0)], non_blocking=True
)
yield
inputs_embeds.fill_(0)
def _get_mm_dummy_batch(
self,
@ -4167,7 +4184,7 @@ class GPUModelRunner(
num_tokens_across_dp[:] = num_tokens_padded
with (
self.maybe_randomize_inputs(input_ids),
self.maybe_randomize_inputs(input_ids, inputs_embeds),
set_forward_context(
attn_metadata,
self.vllm_config,