[Core] Simplify the Dp padding/should ubatch coordination logic (#25768)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Sage Moore 2025-10-06 18:57:49 -07:00 committed by GitHub
parent c50901f3b9
commit 2111b4643c
10 changed files with 297 additions and 462 deletions

View File

@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
split_attn_metadata,
split_decodes_and_prefills,
)
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
@pytest.fixture

View File

@ -152,6 +152,10 @@ class ParallelConfig:
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""
disable_nccl_for_dp_synchronization: bool = False
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
to use Gloo instead of NCCL for its all reduce"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

View File

@ -365,6 +365,9 @@ class EngineArgs:
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization: bool = (
ParallelConfig.disable_nccl_for_dp_synchronization
)
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = (
@ -760,6 +763,10 @@ class EngineArgs:
"--dbo-prefill-token-threshold",
**parallel_kwargs["dbo_prefill_token_threshold"],
)
parallel_group.add_argument(
"--disable-nccl-for-dp-synchronization",
**parallel_kwargs["disable_nccl_for_dp_synchronization"],
)
parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
parallel_group.add_argument(
@ -1437,6 +1444,7 @@ class EngineArgs:
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy,

View File

@ -95,7 +95,6 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
@ -830,12 +829,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
# Swaps the all reduce backend that we use to coordinate the DP padding
# information from NCCL to gloo.
"VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": lambda: (
os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower()
in ("true", "1")
),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")

View File

@ -8,13 +8,11 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
from vllm.v1.worker.ubatch_utils import UBatchSlices
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@ -87,129 +85,22 @@ class DPMetadata:
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: Optional[list[int]] = None
@staticmethod
def num_tokens_across_dp(
num_tokens: int, dp_size: int, dp_rank: int
) -> torch.Tensor:
"""
Gather the num_tokens across all DP ranks and return results in a
CPU tensor of size dp_size.
"""
from vllm.distributed.parallel_state import get_dp_group
device = current_platform.device_type
group = get_dp_group().device_group
# Transferring this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with async
# scheduling. This environment variable exists to quickly disable
# this optimization if we run into this case.
if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
logger.info_once(
"Using CPU all reduce to synchronize DP padding between ranks."
)
device = "cpu"
group = get_dp_group().cpu_group
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(
num_tokens_across_dp, device=device, dtype=torch.int32
)
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()
# Get the cumulative tokens across sequence parallel ranks.
# In this case the input to the MoEs will be distributed w.r.t both
# DP and TP rank.
# When sp_size==1, this is just the cummulative num tokens across DP.
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
num_tokens_across_sp_cpu = (
self.num_tokens_across_dp_cpu - 1 + sp_size
) // sp_size
num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
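For intuition, a small worked example of cu_tokens_across_sp above (token counts are made up): with two DP ranks holding 7 and 5 tokens and sp_size=2, each rank's count is ceil-divided across its SP ranks and then accumulated.

import torch

# Illustrative values only.
num_tokens_across_dp_cpu = torch.tensor([7, 5], dtype=torch.int32)
sp_size = 2
per_sp = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size  # ceil-div: [4, 3]
per_sp = per_sp.repeat_interleave(sp_size)                    # [4, 4, 3, 3]
cu = torch.cumsum(per_sp, dim=0)                              # [4, 8, 11, 14]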
@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
dp_size: int,
dp_rank: int,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]
"""
device = current_platform.device_type
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(tensor, group=get_dp_group().device_group)
result: bool = bool(torch.all(tensor[2] == 1).item())
if not result:
return result, None
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
return False, None
return result, padded_num_tokens_tensor.cpu()
@staticmethod
def make(
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None,
num_tokens_across_dp_cpu: torch.Tensor,
) -> "DPMetadata":
assert num_tokens_across_dp_cpu is not None
assert parallel_config.data_parallel_size > 1
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
)
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
batchsize = num_tokens
# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (
num_tokens_across_dp_cpu is None
or num_tokens_across_dp_cpu[dp_rank] == batchsize
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
if num_tokens_across_dp_cpu is None:
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank
)
assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
@ -376,11 +267,9 @@ def set_forward_context(
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None
):
assert num_tokens_across_dp is not None
dp_metadata = DPMetadata.make(
vllm_config.parallel_config,
attn_metadata,
num_tokens or 0,
num_tokens_across_dp,
vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
)
forward_context = create_forward_context(

vllm/v1/worker/dp_utils.py (new file, 177 lines)
View File

@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
is_second_ubatch_empty,
)
logger = init_logger(__name__)
def _get_device_and_group(parallel_config: ParallelConfig):
device = current_platform.device_type
group = get_dp_group().device_group
# Transferring this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with async
# scheduling. This config option exists to quickly disable
# this optimization if we run into this case.
if parallel_config.disable_nccl_for_dp_synchronization:
logger.info_once("Using CPU all reduce to synchronize DP padding between ranks.")
device = "cpu"
group = get_dp_group().cpu_group
return device, group
def _run_ar(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
dist.all_reduce(tensor, group=group)
return tensor
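For intuition, a hedged illustration of the payload _run_ar exchanges with dp_size=2 (token counts are made up). Each rank fills only its own column before the all-reduce; afterwards every rank holds the full matrix:

import torch

# What every rank sees after dist.all_reduce (illustrative numbers):
tensor = torch.tensor(
    [[96, 100],   # row 0: unpadded token count reported by each rank
     [128, 128],  # row 1: padded token count reported by each rank
     [1, 1]],     # row 2: 1 if that rank wants to microbatch, else 0
    dtype=torch.int32,
)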
def _post_process_ubatch(tensor: torch.Tensor) -> bool:
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
# First determine if we are going to be ubatching.
should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
if not should_ubatch:
return False
# If the DP ranks are planning to ubatch, make sure that
# there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
should_ubatch = False
return should_ubatch
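A hedged worked example of the abort check (numbers made up): if the smallest unpadded count across ranks is 60 while the largest padded count is 128, the first microbatch alone (128 // 2 = 64 tokens) already covers all 60 real tokens on that rank, so its second microbatch would be pure padding and ubatching is abandoned.

orig_min_num_tokens = 60      # smallest unpadded token count across ranks
padded_max_num_tokens = 128   # largest padded token count across ranks
# is_second_ubatch_empty(...) is True here, so _post_process_ubatch returns False.
assert (padded_max_num_tokens // 2) >= orig_min_num_tokens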
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
parallel_config: ParallelConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that they run with the same number
of tokens.
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens for each DP rank, including padding.
]
"""
assert num_tokens_padded >= num_tokens_unpadded
# First we coordinate between the DP ranks via an All Reduce
# to determine the total number of tokens that each rank
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
parallel_config=parallel_config,
)
# Ensure that each rank is processing the same number of tokens.
num_tokens_across_dp = tensor[1, :]
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens_after_padding = torch.tensor(
[max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
)
should_ubatch = _post_process_ubatch(tensor)
return should_ubatch, num_tokens_after_padding
def coordinate_batch_across_dp(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
parallel_config: ParallelConfig,
allow_microbatching: bool,
uniform_decode: bool,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens for each DP rank, including padding.
]
"""
if parallel_config.data_parallel_size == 1:
# Early exit.
return None, None
# Check preconditions for microbatching
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
# If the caller has explicitly disabled microbatching.
if not allow_microbatching:
should_attempt_ubatching = False
(should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
parallel_config,
)
# Don't microbatch unless every other DP worker is also microbatching
if not should_ubatch:
return (None, num_tokens_after_padding)
# This doesn't actually pad the ubatch slices. It just initializes the
# split point to half of the padded token count so that padding can be
# applied to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation.
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
return (ubatch_slices, num_tokens_after_padding)
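A hedged, self-contained usage sketch of the new entry point (the config and request sizes are illustrative); with the default data_parallel_size == 1 the call takes the early-exit path and returns (None, None):

import numpy as np

from vllm.config import ParallelConfig
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp

parallel_config = ParallelConfig()  # defaults: data_parallel_size == 1, DBO off
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
    num_scheduled_tokens_per_request=np.array([4, 4, 4], dtype=np.int32),
    num_tokens_unpadded=12,
    num_tokens_padded=16,
    parallel_config=parallel_config,
    allow_microbatching=True,
    uniform_decode=False,
)
assert ubatch_slices is None and num_tokens_across_dp is None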

View File

@ -41,7 +41,7 @@ from vllm.distributed.parallel_state import (
is_global_first_rank,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
@ -131,12 +131,16 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from .utils import (
@ -1161,18 +1165,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded
)
num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
ubatch_slices, num_tokens_after_padding = ubatch_split(
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
uniform_decode=uniform_decode,
vllm_config=self.vllm_config,
self.parallel_config,
True,
uniform_decode,
)
self.seq_lens.np[:num_reqs] = (
@ -1405,7 +1408,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata,
max_num_scheduled_tokens,
ubatch_slices,
num_tokens_after_padding,
num_tokens_across_dp,
use_cascade_attn,
)
@ -1986,65 +1989,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Determines the total number of tokens that each rank will run.
All ranks will be padded out so that they run with the same number
of tokens
Returns: tuple[
num_pad_tokens: The number of tokens that will be added to the batch
num_tokens_after_padding: A tensor containing the total number of
tokens for each DP rank including padding.
]
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use CUDA graphs (enabled by this padding) on the decoder.
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank
)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor(
[max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32
)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def get_local_padding(self, num_tokens_unpadded: int) -> int:
num_tokens_padded = num_tokens_unpadded
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]
):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if (
self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism
and tp_size > 1
):
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
return num_pad_tokens
# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
@ -2127,13 +2071,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
num_input_tokens: int, # Padded
intermediate_tensors: Optional[IntermediateTensors] = None,
ubatch_slices: Optional[UBatchSlices] = None,
num_tokens_after_padding: Optional[torch.Tensor] = None,
) -> tuple[
int,
int,
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
@ -2141,14 +2082,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dict[str, Any],
]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if ubatch_slices:
assert num_tokens_after_padding is not None
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif ubatch_slices is None:
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
@ -2235,8 +2168,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return (
num_scheduled_tokens,
num_input_tokens,
num_tokens_after_padding,
input_ids,
inputs_embeds,
positions,
@ -2506,24 +2437,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata,
max_query_len,
ubatch_slices,
num_tokens_after_padding,
num_tokens_across_dp,
use_cascade_attn,
) = self._prepare_inputs(scheduler_output)
if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[0].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[0].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
)
(
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(
scheduler_output,
intermediate_tensors,
ubatch_slices,
num_tokens_after_padding,
scheduler_output, num_input_tokens, intermediate_tensors
)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
@ -2548,11 +2485,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
cudagraph_runtime_mode = CUDAGraphMode.NONE
# This is currently to get around the assert in the DPMetadata
# where it wants `num_tokens_across_dp` to align with `num_tokens`
if ubatch_slices is not None:
num_input_tokens = ubatch_slices[0].num_tokens
# Run the model.
# Use persistent buffers for CUDA graphs.
with (
@ -3329,36 +3261,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
ubatch_slices = None
num_tokens_after_padding = None
# We currently only microbatch if the number of tokens is
# over a certain threshold.
if self.parallel_config.enable_dbo and allow_microbatching:
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
uniform_decode=uniform_decode,
vllm_config=self.vllm_config,
)
# Currently when DBO is enabled `ubatch_split` returns
# the num_tokens_after_padding for a single ubatch, but we have 2
# TODO(sage,lucas): this is cruft that should be addressed in the
# padding refactor.
if ubatch_num_tokens_after_padding is not None:
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
# If we failed to microbatch, currently need to resynchronize
# TODO(lucas,sage): we should be able to avoid this second sync by
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
# a single `coordinate_batch_across_dp` function.
if num_tokens_after_padding is None:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens_after_padding = num_tokens + num_pad
else:
num_tokens_across_dp = num_tokens_after_padding
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
self.vllm_config.parallel_config,
allow_microbatching,
uniform_decode,
)
num_tokens_after_padding = num_tokens
if num_tokens_across_dp is not None:
num_tokens_after_padding = int(num_tokens_across_dp[0])
attn_metadata: Optional[PerLayerAttnMetadata] = None

View File

@ -13,6 +13,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
DPMetadata,
create_forward_context,
get_forward_context,
override_forward_context,
@ -409,6 +410,18 @@ class UBatchWrapper:
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
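# Both ubatches are padded to the same per-rank length, so the first slice's
# span gives the per-ubatch token count used to rebuild DPMetadata below.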
num_tokens_per_ubatch = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
)
dp_size = self.vllm_config.parallel_config.data_parallel_size
ubatch_num_tokens_across_dp = torch.tensor(
[num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
)
ubatch_dp_metadata = DPMetadata.make(
self.vllm_config.parallel_config,
num_tokens_per_ubatch,
ubatch_num_tokens_across_dp,
)
if (
num_tokens not in self.cudagraphs
@ -422,7 +435,7 @@ class UBatchWrapper:
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)

View File

@ -1,207 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
from vllm.config import ParallelConfig, VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
is_second_ubatch_empty,
)
logger = init_logger(__name__)
def should_ubatch_with_num_tokens(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
return DPMetadata.should_ubatch_across_dp(
should_ubatch,
orig_num_tokens_per_ubatch,
padded_num_tokens_per_ubatch,
dp_size,
dp_rank,
)
def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold
def get_dp_padding_ubatch(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]
"""
assert num_tokens_padded >= num_tokens_unpadded
dp_size = vllm_config.parallel_config.data_parallel_size
if dp_size == 1:
# Early exit.
return False, None
# If this DP rank doesn't want to attempt microbatching
if not should_attempt_ubatching:
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
False, 0, 0, vllm_config
)
assert should_ubatch is False
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp
# Round up to the next multiple of two for even divisibility
num_tokens_padded = round_up(num_tokens_padded, 2)
num_tokens_per_ubatch = num_tokens_padded // 2
should_ubatch = True
# Sanity Check that the existing padding isn't giving us an empty second
# ubatch. Abort if so
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.debug(
"Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s",
num_tokens_unpadded,
num_tokens_padded,
)
should_ubatch = False
# Note that we compute the number of padded tokens per ubatch
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, vllm_config
)
if not should_ubatch:
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp
assert num_tokens_across_dp is not None
max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
num_tokens_after_padding = torch.tensor(
[max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32
)
return should_ubatch, num_tokens_after_padding
def create_ubatch_slices(
num_scheduled_tokens: np.ndarray, split_point: int
) -> UBatchSlices:
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
return [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
def ubatch_split(
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
uniform_decode: bool,
vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if ubatch_slices is None
]
"""
parallel_config = vllm_config.parallel_config
# Don't bother with the should_ubatch handshaking unless microbatching
# is enabled
if not parallel_config.enable_dbo:
return (None, None)
# Check preconditions for microbatching
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
uniform_decode=uniform_decode,
)
# Don't microbatch unless every other DP worker is also microbatching
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
vllm_config,
)
if not should_ubatch:
return (None, None)
# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item())
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
return (ubatch_slices, num_tokens_after_padding)

View File

@ -2,8 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
from typing_extensions import TypeAlias
from vllm.config import ParallelConfig
@dataclass
class UBatchSlice:
@ -24,7 +27,47 @@ class UBatchSlice:
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(
    orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int
) -> bool:
    return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
    return (padded_num_tokens // 2) >= orig_num_tokens
def check_ubatch_thresholds(
    config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
    if not config.enable_dbo:
        return False
    if uniform_decode:
        return num_tokens >= config.dbo_decode_token_threshold
    else:
        return num_tokens >= config.dbo_prefill_token_threshold
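A hedged illustration of the threshold gate above (the config values are made up, not the project defaults):

from vllm.config import ParallelConfig
from vllm.v1.worker.ubatch_utils import check_ubatch_thresholds

# Illustrative config: DBO on, with a made-up decode threshold of 32 tokens.
cfg = ParallelConfig(enable_dbo=True, dbo_decode_token_threshold=32)
assert check_ubatch_thresholds(cfg, num_tokens=64, uniform_decode=True)
assert not check_ubatch_thresholds(cfg, num_tokens=16, uniform_decode=True)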
def create_ubatch_slices(
num_scheduled_tokens: np.ndarray, split_point: int
) -> UBatchSlices:
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
return [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
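Finally, a hedged worked example of the split logic above (request sizes made up): three requests of 4 tokens each with a split point of 6 put the boundary inside request 1, so that request contributes tokens to both microbatches.

import numpy as np

from vllm.v1.worker.ubatch_utils import create_ubatch_slices

# cu_num_tokens = [0, 4, 8, 12]; split_point = 6 falls inside request 1.
first, second = create_ubatch_slices(np.array([4, 4, 4], dtype=np.int32), 6)
assert first.token_slice == slice(0, 6)    # requests 0-1 contribute tokens [0, 6)
assert second.token_slice == slice(6, 12)  # requests 1-2 contribute tokens [6, 12)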