[Bugfix] Make DP padding optional in coordinate_batch_across_dp (#26375)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore, 2025-10-10 07:53:33 -07:00, committed by GitHub
parent 0e67102d93
commit ae9d0e7da5
3 changed files with 123 additions and 42 deletions
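
In short: coordinate_batch_across_dp gains an allow_dp_padding flag, the microbatching-related arguments become optional keywords, and set_forward_context can now build num_tokens_across_dp itself with both behaviors disabled. A minimal sketch of the opt-out call, mirroring the fallback added below (num_tokens and vllm_config stand in for whatever the caller has in scope):

    from vllm.v1.worker.dp_utils import coordinate_batch_across_dp

    # Exchange token counts across DP ranks, but keep this rank's own count
    # (no padding to a common size) and skip the microbatching handshake.
    _, num_tokens_across_dp = coordinate_batch_across_dp(
        num_tokens_unpadded=num_tokens,                  # this rank's scheduled tokens
        parallel_config=vllm_config.parallel_config,
        allow_microbatching=False,
        allow_dp_padding=False,
    )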

vllm/forward_context.py

@@ -12,6 +12,7 @@ import torch
import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
if TYPE_CHECKING:
@@ -278,7 +279,19 @@ def set_forward_context(
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None
):
assert num_tokens_across_dp is not None
# If num_tokens_across_dp hasn't already been initialized, then
# initialize it here. Both DP padding and Microbatching will be
# disabled.
if num_tokens_across_dp is None:
assert ubatch_slices is None
assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=False,
)
assert num_tokens_across_dp is not None
dp_metadata = DPMetadata.make(
vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
)

vllm/v1/worker/dp_utils.py

@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
@@ -37,6 +37,7 @@ def _get_device_and_group(parallel_config: ParallelConfig):
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
parallel_config: ParallelConfig,
@@ -44,10 +45,11 @@ def _run_ar(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
dist.all_reduce(tensor, group=group)
return tensor
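
The widened payload can be pictured without a distributed setup. A standalone sketch (hypothetical token counts, not part of the diff) of what the summed all-reduce produces when each rank fills only its own column; row 3 is the new DP-padding vote:

    import torch

    dp_size = 2
    # (orig_tokens, padded_tokens, ubatch_vote, dp_pad_vote) contributed by each rank
    per_rank_payloads = [(7, 8, 0, 1), (3, 4, 0, 1)]

    contributions = []
    for rank, values in enumerate(per_rank_payloads):
        t = torch.zeros(4, dp_size, dtype=torch.int32)
        for row, value in enumerate(values):
            t[row][rank] = value
        contributions.append(t)

    # dist.all_reduce defaults to SUM, so the result stitches the columns together:
    reduced = torch.stack(contributions).sum(dim=0)
    # tensor([[7, 3],     row 0: unpadded tokens per rank
    #         [8, 4],     row 1: padded tokens per rank
    #         [0, 0],     row 2: microbatch votes
    #         [1, 1]])    row 3: DP padding votes (the new row)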
@@ -72,10 +74,26 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
return should_ubatch
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
num_tokens_across_dp = tensor[1, :]
if should_dp_pad:
# If DP padding is enabled, ensure that each rank is processing the same number
# of tokens
max_num_tokens = int(num_tokens_across_dp.max().item())
return torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp),
device="cpu",
dtype=torch.int32,
)
else:
return num_tokens_across_dp.cpu()
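
A tiny worked example of the two branches above, assuming _post_process_dp_padding is in scope and reusing the hypothetical reduced tensor from the earlier sketch (row 1 holds the per-rank padded token counts):

    import torch

    reduced = torch.tensor([[7, 3], [8, 4], [0, 0], [1, 1]], dtype=torch.int32)

    _post_process_dp_padding(reduced, should_dp_pad=True)
    # -> tensor([8, 8], dtype=torch.int32): every rank is padded up to the max
    _post_process_dp_padding(reduced, should_dp_pad=False)
    # -> tensor([8, 4], dtype=torch.int32): each rank keeps its own padded count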
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
parallel_config: ParallelConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
@@ -83,57 +101,88 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that they run with the same number of tokens.
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding.
tokens per-microbatch for each DP rank including any DP padding.
]
"""
assert num_tokens_padded >= num_tokens_unpadded
# First we coordinate between the DP ranks via an All Reduce
# Coordinate between the DP ranks via an All Reduce
# to determine the total number of tokens that each rank
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
parallel_config=parallel_config,
)
# Ensure that each rank is processing the same nuber of tokens
num_tokens_across_dp = tensor[1, :]
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens_after_padding = torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor)
if should_ubatch and not should_dp_pad:
if is_global_first_rank():
logger.debug(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled."
)
should_dp_pad = True
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
num_tokens_after_padding = _post_process_dp_padding(
tensor,
should_dp_pad,
)
return should_ubatch, num_tokens_after_padding
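
The net effect of the flag handling above reduces to should_dp_pad or should_ubatch: the rank-wide opt-out of padding is honored unless microbatching was agreed on, in which case padding is forced (per the debug message). An illustrative reduction, not part of the module:

    def resolve_dp_padding(should_ubatch: bool, should_dp_pad: bool) -> bool:
        # Microbatching needs DP padding, so it overrides the opt-out.
        return should_dp_pad or should_ubatch

    assert resolve_dp_padding(False, False) is False   # e.g. eager prefill: no padding
    assert resolve_dp_padding(False, True) is True     # padding requested: pad to the max
    assert resolve_dp_padding(True, False) is True     # agreed ubatching forces padding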
def coordinate_batch_across_dp(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
parallel_config: ParallelConfig,
allow_microbatching: bool,
uniform_decode: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: Optional[int] = None,
uniform_decode: Optional[bool] = None,
num_scheduled_tokens_per_request: Optional[np.ndarray] = None,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
uniform_decode: Only used if allow_microbatching is True. True if the batch
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding.
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
]
"""
@@ -141,21 +190,25 @@ def coordinate_batch_across_dp(
# Early exit.
return None, None
# Check preconditions for microbatching
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
# If the caller has explicitly enabled microbatching.
should_attempt_ubatching = False
if allow_microbatching:
# Check preconditions for microbatching
assert uniform_decode is not None
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
# If the caller has explicitly disabled microbatching.
if not allow_microbatching:
should_attempt_ubatching = False
if num_tokens_padded is None:
num_tokens_padded = num_tokens_unpadded
(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
parallel_config,
)
@@ -170,6 +223,7 @@ def coordinate_batch_across_dp(
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2
assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)

vllm/v1/worker/gpu_model_runner.py

@@ -1178,13 +1178,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set enforce_eager on the prefiller in
# a P/D setup and still use CUDA graphs (enabled by this padding) on the
# decoder.
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
self.parallel_config,
True,
uniform_decode,
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.parallel_config,
allow_microbatching=True,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens,
)
self.seq_lens.np[:num_reqs] = (
@@ -2436,12 +2444,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_cascade_attn,
) = self._prepare_inputs(scheduler_output)
dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[0].item())
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[0].item())
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
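
Why these call sites now index by dp_rank rather than 0: once DP padding can be skipped, the entries of num_tokens_across_dp may differ per rank, so each rank must read its own entry. A quick illustration with hypothetical counts:

    import torch

    num_tokens_across_dp = torch.tensor([8, 4], dtype=torch.int32)  # padding disabled
    dp_rank = 1
    num_input_tokens = int(num_tokens_across_dp[dp_rank].item())  # 4: this rank's count
    stale = int(num_tokens_across_dp[0].item())                   # 8: only right on rank 0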
@@ -3256,19 +3265,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
# Disable DP padding when running eager
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
# We currently only microbatch if the number of tokens is
# over a certain threshold.
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
self.vllm_config.parallel_config,
allow_microbatching,
uniform_decode,
num_tokens_unpadded=total_num_scheduled_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=total_num_scheduled_tokens,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens,
)
num_tokens_after_padding = num_tokens
if num_tokens_across_dp is not None:
num_tokens_after_padding = int(num_tokens_across_dp[0])
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
attn_metadata: Optional[PerLayerAttnMetadata] = None