From 2111b4643c1039894dc730a32eec78e937767532 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 6 Oct 2025 18:57:49 -0700 Subject: [PATCH] [Core] Simplify the Dp padding/should ubatch coordination logic (#25768) Signed-off-by: Sage Moore Signed-off-by: mgoin Co-authored-by: mgoin --- .../v1/attention/test_attention_splitting.py | 2 +- vllm/config/parallel.py | 4 + vllm/engine/arg_utils.py | 8 + vllm/envs.py | 7 - vllm/forward_context.py | 129 +---------- vllm/v1/worker/dp_utils.py | 177 +++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 161 ++++---------- vllm/v1/worker/gpu_ubatch_wrapper.py | 15 +- vllm/v1/worker/ubatch_splitting.py | 207 ------------------ vllm/v1/worker/ubatch_utils.py | 49 ++++- 10 files changed, 297 insertions(+), 462 deletions(-) create mode 100644 vllm/v1/worker/dp_utils.py delete mode 100644 vllm/v1/worker/ubatch_splitting.py diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 6335d2a7db5e..1cbd0fe56be6 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import ( split_attn_metadata, split_decodes_and_prefills, ) -from vllm.v1.worker.ubatch_splitting import create_ubatch_slices +from vllm.v1.worker.ubatch_utils import create_ubatch_slices @pytest.fixture diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 649b2434ebbf..7a9be2198f03 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -152,6 +152,10 @@ class ParallelConfig: threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" + disable_nccl_for_dp_synchronization: bool = False + """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py + to use Gloo instead of NCCL for its all reduce""" + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 942384688184..e01f2d32d914 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -365,6 +365,9 @@ class EngineArgs: enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold + disable_nccl_for_dp_synchronization: bool = ( + ParallelConfig.disable_nccl_for_dp_synchronization + ) eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb expert_placement_strategy: ExpertPlacementStrategy = ( @@ -760,6 +763,10 @@ class EngineArgs: "--dbo-prefill-token-threshold", **parallel_kwargs["dbo_prefill_token_threshold"], ) + parallel_group.add_argument( + "--disable-nccl-for-dp-synchronization", + **parallel_kwargs["disable_nccl_for_dp_synchronization"], + ) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) parallel_group.add_argument( @@ -1437,6 +1444,7 @@ class EngineArgs: enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, + disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, 
expert_placement_strategy=self.expert_placement_strategy, diff --git a/vllm/envs.py b/vllm/envs.py index 2b915b02e48f..1cd03240d7e6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -95,7 +95,6 @@ if TYPE_CHECKING: VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] - VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False VLLM_DISABLE_PYNCCL: bool = False VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False @@ -830,12 +829,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ["VLLM_DISABLED_KERNELS"].split(","), - # Swaps the all reduce backend that we use to coordinate the DP padding - # information from NCCL to gloo. - "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": lambda: ( - os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() - in ("true", "1") - ), # Disable pynccl (using torch.distributed instead) "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 26ad37dda776..a6a1e36bfe95 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -8,13 +8,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch -import torch.distributed as dist import vllm.envs as envs from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty +from vllm.v1.worker.ubatch_utils import UBatchSlices if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -87,129 +85,22 @@ class DPMetadata: # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: Optional[list[int]] = None - @staticmethod - def num_tokens_across_dp( - num_tokens: int, dp_size: int, dp_rank: int - ) -> torch.Tensor: - """ - Gather the num_tokens across all DP ranks and return results in a - CPU tensor of size dp_size. - """ - from vllm.distributed.parallel_state import get_dp_group - - device = current_platform.device_type - group = get_dp_group().device_group - - # Transfering this tensor from GPU to CPU will introduce a GPU sync - # point that could adversely affect performance of vllm with asynch - # scheduling. This environment variable exists to quickly disable - # this optimization if we run into this case. - if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: - logger.info_once( - "Using CPU all reduce to syncronize DP padding between ranks." - ) - device = "cpu" - group = get_dp_group().cpu_group - num_tokens_across_dp = [0] * dp_size - num_tokens_across_dp[dp_rank] = num_tokens - num_tokens_tensor = torch.tensor( - num_tokens_across_dp, device=device, dtype=torch.int32 - ) - dist.all_reduce(num_tokens_tensor, group=group) - return num_tokens_tensor.cpu() - - # Get the cumulative tokens across sequence parallel ranks. - # In this case the input to the MoEs will be distributed w.r.t both - # DP and TP rank. - # When sp_size==1, this is just the cummulative num tokens across DP. 
- def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: - num_tokens_across_sp_cpu = ( - self.num_tokens_across_dp_cpu - 1 + sp_size - ) // sp_size - num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size) - return torch.cumsum(num_tokens_across_sp_cpu, dim=0) - - @staticmethod - def should_ubatch_across_dp( - should_ubatch: bool, - orig_num_tokens_per_ubatch: int, - padded_num_tokens_per_ubatch: int, - dp_size: int, - dp_rank: int, - ) -> tuple[bool, Optional[torch.Tensor]]: - """ - 1. Decides if each DP rank is going to microbatch. Either all ranks - run with microbatching or none of them do. If this function decides - not to run with microbatching. It will "abort" meaning that no padding - information will be returned to the caller. It will return (False, None) - - 2. Determines the total number of tokens that each rank will run. - All ranks will be padded out so that the run with the same number - of tokens - - Returns: tuple[ - should_ubatch: Are all DP ranks going to microbatch - num_tokens_after_padding: A tensor containing the total number of - tokens per-microbatch for each DP rank including padding. Will be - None if should_ubatch if False - ] - """ - - device = current_platform.device_type - tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32) - tensor[0][dp_rank] = orig_num_tokens_per_ubatch - tensor[1][dp_rank] = padded_num_tokens_per_ubatch - tensor[2][dp_rank] = 1 if should_ubatch else 0 - - from vllm.distributed.parallel_state import get_dp_group - - dist.all_reduce(tensor, group=get_dp_group().device_group) - - result: bool = bool(torch.all(tensor[2] == 1).item()) - if not result: - return result, None - - orig_num_tokens_tensor = tensor[0, :] - padded_num_tokens_tensor = tensor[1, :] - - orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) - padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) - if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): - logger.debug( - "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens - ) - return False, None - return result, padded_num_tokens_tensor.cpu() - @staticmethod def make( parallel_config: ParallelConfig, - attn_metadata: Any, num_tokens: int, - num_tokens_across_dp_cpu: Optional[torch.Tensor] = None, + num_tokens_across_dp_cpu: torch.Tensor, ) -> "DPMetadata": + assert num_tokens_across_dp_cpu is not None assert parallel_config.data_parallel_size > 1 - dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"): - # for v0 attention backends - batchsize = ( - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens - ) - else: - # for v1 attention backends or no attn_metadata - batchsize = num_tokens + batchsize = num_tokens # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert ( - num_tokens_across_dp_cpu is None - or num_tokens_across_dp_cpu[dp_rank] == batchsize - ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" - if num_tokens_across_dp_cpu is None: - num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp( - batchsize, dp_size, dp_rank - ) + assert num_tokens_across_dp_cpu[dp_rank] == batchsize, ( + f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) @@ -376,11 +267,9 
@@ def set_forward_context(
     if vllm_config.parallel_config.data_parallel_size > 1 and (
         attn_metadata is not None or num_tokens is not None
     ):
+        assert num_tokens_across_dp is not None
         dp_metadata = DPMetadata.make(
-            vllm_config.parallel_config,
-            attn_metadata,
-            num_tokens or 0,
-            num_tokens_across_dp,
+            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )

     forward_context = create_forward_context(
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
new file mode 100644
index 000000000000..7a943909a8ba
--- /dev/null
+++ b/vllm/v1/worker/dp_utils.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from vllm.config import ParallelConfig
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.v1.worker.ubatch_utils import (
+    UBatchSlices,
+    check_ubatch_thresholds,
+    create_ubatch_slices,
+    is_second_ubatch_empty,
+)
+
+logger = init_logger(__name__)
+
+
+def _get_device_and_group(parallel_config: ParallelConfig):
+    device = current_platform.device_type
+    group = get_dp_group().device_group
+
+    # Transferring this tensor from GPU to CPU will introduce a GPU sync
+    # point that could adversely affect performance of vLLM with async
+    # scheduling. This config option exists to quickly disable
+    # this optimization if we run into this case.
+    if parallel_config.disable_nccl_for_dp_synchronization:
+        logger.info_once("Using CPU all reduce to synchronize DP padding between ranks.")
+        device = "cpu"
+        group = get_dp_group().cpu_group
+    return device, group
+
+
+def _run_ar(
+    should_ubatch: bool,
+    orig_num_tokens_per_ubatch: int,
+    padded_num_tokens_per_ubatch: int,
+    parallel_config: ParallelConfig,
+) -> torch.Tensor:
+    dp_size = parallel_config.data_parallel_size
+    dp_rank = parallel_config.data_parallel_rank
+    device, group = _get_device_and_group(parallel_config)
+    tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+    tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+    tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+    tensor[2][dp_rank] = 1 if should_ubatch else 0
+    dist.all_reduce(tensor, group=group)
+    return tensor
+
+
+def _post_process_ubatch(tensor: torch.Tensor) -> bool:
+    orig_num_tokens_tensor = tensor[0, :]
+    padded_num_tokens_tensor = tensor[1, :]
+
+    # First determine if we are going to be ubatching.
+    should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
+    if not should_ubatch:
+        return False
+    # If the DP ranks are planning to ubatch, make sure that
+    # there are no "empty" second ubatches
+    orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+    padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+    if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+        logger.debug(
+            "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
+        )
+        should_ubatch = False
+    return should_ubatch
+
+
+def _synchronize_dp_ranks(
+    num_tokens_unpadded: int,
+    num_tokens_padded: int,
+    should_attempt_ubatching: bool,
+    parallel_config: ParallelConfig,
+) -> tuple[bool, Optional[torch.Tensor]]:
+    """
+    1. Decides if each DP rank is going to microbatch. Either all ranks
+    run with microbatching or none of them do.
+
+    2. Determines the total number of tokens that each rank will run.
+    All ranks will be padded out so that they run with the same number
+    of tokens.
+
+    Returns: tuple[
+        should_ubatch: Are all DP ranks going to microbatch
+        num_tokens_after_padding: A tensor containing the total number of
+        tokens per-microbatch for each DP rank including padding.
+    ]
+
+    """
+    assert num_tokens_padded >= num_tokens_unpadded
+
+    # First we coordinate between the DP ranks via an All Reduce
+    # to determine the total number of tokens that each rank
+    # will run and if we are using ubatching or not.
+    tensor = _run_ar(
+        should_ubatch=should_attempt_ubatching,
+        orig_num_tokens_per_ubatch=num_tokens_unpadded,
+        padded_num_tokens_per_ubatch=num_tokens_padded,
+        parallel_config=parallel_config,
+    )
+
+    # Ensure that each rank is processing the same number of tokens
+    num_tokens_across_dp = tensor[1, :]
+    max_num_tokens = int(num_tokens_across_dp.max().item())
+    num_tokens_after_padding = torch.tensor(
+        [max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
+    )
+
+    should_ubatch = _post_process_ubatch(tensor)
+
+    return should_ubatch, num_tokens_after_padding
+
+
+def coordinate_batch_across_dp(
+    num_scheduled_tokens_per_request: np.ndarray,
+    num_tokens_unpadded: int,
+    num_tokens_padded: int,
+    parallel_config: ParallelConfig,
+    allow_microbatching: bool,
+    uniform_decode: bool,
+) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
+    """
+    Coordinates amongst all DP ranks to determine if and how the full batch
+    should be split into microbatches.
+
+    Returns: tuple[
+        ubatch_slices: if this is set then all DP ranks have agreed to
+        microbatch
+        num_tokens_after_padding: A tensor containing the total number of
+        tokens per-microbatch for each DP rank including padding.
+    ]
+
+    """
+    if parallel_config.data_parallel_size == 1:
+        # Early exit.
+        return None, None
+
+    # Check preconditions for microbatching
+    should_attempt_ubatching = check_ubatch_thresholds(
+        parallel_config,
+        num_tokens_unpadded,
+        uniform_decode=uniform_decode,
+    )
+
+    # If the caller has explicitly disabled microbatching.
+    if not allow_microbatching:
+        should_attempt_ubatching = False
+
+    (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
+        num_tokens_unpadded,
+        num_tokens_padded,
+        should_attempt_ubatching,
+        parallel_config,
+    )
+
+    # Don't microbatch unless every other DP worker is also microbatching
+    if not should_ubatch:
+        return (None, num_tokens_after_padding)
+
+    # This doesn't actually pad the ubatch slices.
It just initializes the + # split point to the padded value so that padding can be applied + # to the second ubatch in pad_out_ubatch_slice after attention + # metadata creation + assert num_tokens_after_padding is not None + token_split_point = int(num_tokens_after_padding[0].item()) // 2 + + ubatch_slices = create_ubatch_slices( + num_scheduled_tokens_per_request, token_split_point + ) + + return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3154395e188f..bd799c06c0eb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -41,7 +41,7 @@ from vllm.distributed.parallel_state import ( is_global_first_rank, prepare_communication_buffer_for_model, ) -from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase @@ -131,12 +131,16 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split -from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + UBatchSlices, + check_ubatch_thresholds, +) from vllm.v1.worker.utils import is_residual_scattered_for_sp from .utils import ( @@ -1161,18 +1165,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded - ) + num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded) uniform_decode = ( max_num_scheduled_tokens == self.uniform_decode_query_len ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = ubatch_split( + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( num_scheduled_tokens, num_tokens_unpadded, num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config, + self.parallel_config, + True, + uniform_decode, ) self.seq_lens.np[:num_reqs] = ( @@ -1405,7 +1408,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_common_attn_metadata, max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding, + num_tokens_across_dp, use_cascade_attn, ) @@ -1986,65 +1989,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - """ - Determines the total number of tokens that each rank will run. 
- All ranks will be padded out so that they run with the same number - of tokens - - Returns: tuple[ - num_pad_tokens: The number of tokens that will be added to the batch - num_tokens_after_padding: A tensor containing the total number of - tokens for each DP rank including padding. - ] - """ - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. - # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank - ) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor( - [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 - ) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - - def get_local_padding(self, num_tokens_unpadded: int) -> int: - num_tokens_padded = num_tokens_unpadded - - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1] - ): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if ( - self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1 - ): - num_tokens_padded = round_up(num_tokens_unpadded, tp_size) - - num_pad_tokens = num_tokens_padded - num_tokens_unpadded - return num_pad_tokens - # This is where the second ubatch is adjusted to account for the padding. # Should be called after attention metadata creation. 
This just pads # the second ubatch slice out to the total number of tokens @@ -2127,13 +2071,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _preprocess( self, scheduler_output: "SchedulerOutput", + num_input_tokens: int, # Padded intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, ) -> tuple[ int, - int, - Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, @@ -2141,14 +2082,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dict[str, Any], ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if ubatch_slices: - assert num_tokens_after_padding is not None - num_input_tokens = int(num_tokens_after_padding[0].item() * 2) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -2235,8 +2168,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return ( num_scheduled_tokens, - num_input_tokens, - num_tokens_after_padding, input_ids, inputs_embeds, positions, @@ -2506,24 +2437,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_common_attn_metadata, max_query_len, ubatch_slices, - num_tokens_after_padding, + num_tokens_across_dp, use_cascade_attn, ) = self._prepare_inputs(scheduler_output) + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[0].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[0].item()) + else: + num_input_tokens = self._get_num_input_tokens( + scheduler_output.total_num_scheduled_tokens + ) + ( num_scheduled_tokens, - num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, ) = self._preprocess( - scheduler_output, - intermediate_tensors, - ubatch_slices, - num_tokens_after_padding, + scheduler_output, num_input_tokens, intermediate_tensors ) uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( @@ -2548,11 +2485,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ): cudagraph_runtime_mode = CUDAGraphMode.NONE - # This is currently to get around the assert in the DPMetadata - # where it wants `num_tokens_across_dp` to align with `num_tokens` - if ubatch_slices is not None: - num_input_tokens = ubatch_slices[0].num_tokens - # Run the model. # Use persistent buffers for CUDA graphs. with ( @@ -3329,36 +3261,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) - ubatch_slices = None - num_tokens_after_padding = None - # We currently only microbatch if the number of tokens is # over a certain threshold. 
- if self.parallel_config.enable_dbo and allow_microbatching: - ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split( - num_scheduled_tokens, - total_num_scheduled_tokens, - total_num_scheduled_tokens, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config, - ) - # Currently when DBO is enabled `ubatch_split` returns - # the num_tokens_after_padding for a single ubatch, but we have 2 - # TODO(sage,lucas): this is cruft that should be addressed in the - # padding refactor. - if ubatch_num_tokens_after_padding is not None: - num_tokens_after_padding = ubatch_num_tokens_after_padding * 2 - - # If we failed to microbatch, currently need to resynchronize - # TODO(lucas,sage): we should be able to avoid this second sync by - # refactoring `get_dp_padding_ubatch` and `get_dp_padding` into - # a single `coordinate_batch_across_dp` function. - if num_tokens_after_padding is None: - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens_after_padding = num_tokens + num_pad - else: - num_tokens_across_dp = num_tokens_after_padding - num_tokens_after_padding = int(num_tokens_after_padding[0].item()) + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_scheduled_tokens, + total_num_scheduled_tokens, + total_num_scheduled_tokens, + self.vllm_config.parallel_config, + allow_microbatching, + uniform_decode, + ) + num_tokens_after_padding = num_tokens + if num_tokens_across_dp is not None: + num_tokens_after_padding = int(num_tokens_across_dp[0]) attn_metadata: Optional[PerLayerAttnMetadata] = None diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 3bd7c9d538de..fb63fe8d2543 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -13,6 +13,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_ep_group from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import ( + DPMetadata, create_forward_context, get_forward_context, override_forward_context, @@ -409,6 +410,18 @@ class UBatchWrapper: # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None + num_tokens_per_ubatch = ( + ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start + ) + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata = DPMetadata.make( + self.vllm_config.parallel_config, + num_tokens_per_ubatch, + ubatch_num_tokens_across_dp, + ) if ( num_tokens not in self.cudagraphs @@ -422,7 +435,7 @@ class UBatchWrapper: intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, compute_stream=compute_stream, - dp_metadata=dp_metadata, + dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, ) diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py deleted file mode 100644 index 6723239e8495..000000000000 --- a/vllm/v1/worker/ubatch_splitting.py +++ /dev/null @@ -1,207 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import numpy as np -import torch - -from vllm.config import ParallelConfig, VllmConfig -from vllm.forward_context import DPMetadata -from vllm.logger import init_logger -from vllm.utils import round_up -from 
vllm.v1.worker.ubatch_utils import ( - UBatchSlice, - UBatchSlices, - is_second_ubatch_empty, -) - -logger = init_logger(__name__) - - -def should_ubatch_with_num_tokens( - should_ubatch: bool, - orig_num_tokens_per_ubatch: int, - padded_num_tokens_per_ubatch: int, - vllm_config: VllmConfig, -) -> tuple[bool, Optional[torch.Tensor]]: - dp_size = vllm_config.parallel_config.data_parallel_size - dp_rank = vllm_config.parallel_config.data_parallel_rank - return DPMetadata.should_ubatch_across_dp( - should_ubatch, - orig_num_tokens_per_ubatch, - padded_num_tokens_per_ubatch, - dp_size, - dp_rank, - ) - - -def check_ubatch_thresholds( - config: ParallelConfig, num_tokens: int, uniform_decode: bool -) -> bool: - if not config.enable_dbo: - return False - if uniform_decode: - return num_tokens >= config.dbo_decode_token_threshold - else: - return num_tokens >= config.dbo_prefill_token_threshold - - -def get_dp_padding_ubatch( - num_tokens_unpadded: int, - num_tokens_padded: int, - should_attempt_ubatching: bool, - vllm_config: VllmConfig, -) -> tuple[bool, Optional[torch.Tensor]]: - """ - 1. Decides if each DP rank is going to microbatch. Either all ranks - run with microbatching or none of them do. If this function decides - not to run with microbatching. It will "abort" meaning that no padding - information will be returned to the caller. It will return (False, None) - - 2. Determines the total number of tokens that each rank will run. - All ranks will be padded out so that the run with the same number - of tokens - - Returns: tuple[ - should_ubatch: Are all DP ranks going to microbatch - num_tokens_after_padding: A tensor containing the total number of - tokens per-microbatch for each DP rank including padding. Will be - None if should_ubatch if False - ] - - """ - assert num_tokens_padded >= num_tokens_unpadded - dp_size = vllm_config.parallel_config.data_parallel_size - if dp_size == 1: - # Early exit. - return False, None - - # If this DP rank doesn't want to attempt microbatching - if not should_attempt_ubatching: - (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - False, 0, 0, vllm_config - ) - assert should_ubatch is False - assert num_tokens_across_dp is None - return should_ubatch, num_tokens_across_dp - - # Round up to the next multiple of two for even divisibility - num_tokens_padded = round_up(num_tokens_padded, 2) - num_tokens_per_ubatch = num_tokens_padded // 2 - should_ubatch = True - - # Sanity Check that the existing padding isn't giving us an empty second - # ubatch. 
Abort if so - if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded): - logger.debug( - "Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s", - num_tokens_unpadded, - num_tokens_padded, - ) - should_ubatch = False - - # Note that we compute the number of padded tokens per ubatch - (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, vllm_config - ) - if not should_ubatch: - assert num_tokens_across_dp is None - return should_ubatch, num_tokens_across_dp - - assert num_tokens_across_dp is not None - - max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item()) - num_tokens_after_padding = torch.tensor( - [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 - ) - return should_ubatch, num_tokens_after_padding - - -def create_ubatch_slices( - num_scheduled_tokens: np.ndarray, split_point: int -) -> UBatchSlices: - # TODO(lucas): Refactor the gpu_model_runner.py so we can pass - # in cu_num_tokens directly (i.e. query_start_loc) - cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) - np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) - - first_ubatch_token_slice = slice(0, split_point) - second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) - - # Determine request slices using exclusive stop semantics - # First ubatch includes requests whose tokens overlap [0, split_point) - first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left") - ) - first_ubatch_req_slice = slice(0, first_ubatch_req_stop) - - # Second ubatch starts at the request that contains the split_point - # or the request starting exactly at split_point (if on boundary) - second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1 - ) - second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) - - return [ - UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), - ] - - -def ubatch_split( - num_scheduled_tokens_per_request: np.ndarray, - num_tokens_unpadded: int, - num_tokens_padded: int, - uniform_decode: bool, - vllm_config: VllmConfig, -) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: - """ - Coordinates amongst all DP ranks to determine if and how the full batch - should be split into microbatches. - - Returns: tuple[ - ubatch_slices: if this is set then all DP ranks have agreed to - microbatch - num_tokens_after_padding: A tensor containing the total number of - tokens per-microbatch for each DP rank including padding. Will be - None if ubatch_slices is None - ] - - """ - parallel_config = vllm_config.parallel_config - # Don't bother with the should_ubatch handshaking unless microbatching - # is enabled - if not parallel_config.enable_dbo: - return (None, None) - - # Check preconditions for microbatching - should_attempt_ubatching = check_ubatch_thresholds( - parallel_config, - num_tokens_unpadded, - uniform_decode=uniform_decode, - ) - - # Don't microbatch unless every other DP worker is also microbatching - should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch( - num_tokens_unpadded, - num_tokens_padded, - should_attempt_ubatching, - vllm_config, - ) - - if not should_ubatch: - return (None, None) - - # This doesn't actually pad the ubatch slices. 
It just initializes the - # split point to the padded value so that padding can be applied - # to the second ubatch in pad_out_ubatch_slice after attention - # metadata creation - assert num_tokens_after_padding is not None - token_split_point = int(num_tokens_after_padding[0].item()) - - ubatch_slices = create_ubatch_slices( - num_scheduled_tokens_per_request, token_split_point - ) - - return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 2deba16f8a49..ef22977e094b 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -2,8 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +import numpy as np from typing_extensions import TypeAlias +from vllm.config import ParallelConfig + @dataclass class UBatchSlice: @@ -24,7 +27,47 @@ class UBatchSlice: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty( - orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int +def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: + return (padded_num_tokens // 2) >= orig_num_tokens + + +def check_ubatch_thresholds( + config: ParallelConfig, num_tokens: int, uniform_decode: bool ) -> bool: - return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch + if not config.enable_dbo: + return False + if uniform_decode: + return num_tokens >= config.dbo_decode_token_threshold + else: + return num_tokens >= config.dbo_prefill_token_threshold + + +def create_ubatch_slices( + num_scheduled_tokens: np.ndarray, split_point: int +) -> UBatchSlices: + # TODO(lucas): Refactor the gpu_model_runner.py so we can pass + # in cu_num_tokens directly (i.e. query_start_loc) + cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) + np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) + + first_ubatch_token_slice = slice(0, split_point) + second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + + # Determine request slices using exclusive stop semantics + # First ubatch includes requests whose tokens overlap [0, split_point) + first_ubatch_req_stop = int( + np.searchsorted(cu_num_tokens, split_point, side="left") + ) + first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + + # Second ubatch starts at the request that contains the split_point + # or the request starting exactly at split_point (if on boundary) + second_ubatch_req_start = int( + np.searchsorted(cu_num_tokens, split_point, side="right") - 1 + ) + second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) + + return [ + UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), + ]