diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 37831e02f53f4..36f3062a9e3a0 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -12,6 +12,7 @@ import torch
 import vllm.envs as envs
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.ubatch_utils import UBatchSlices

 if TYPE_CHECKING:
@@ -278,7 +279,19 @@ def set_forward_context(
     if vllm_config.parallel_config.data_parallel_size > 1 and (
         attn_metadata is not None or num_tokens is not None
     ):
-        assert num_tokens_across_dp is not None
+        # If num_tokens_across_dp hasn't already been initialized, then
+        # initialize it here. Both DP padding and Microbatching will be
+        # disabled.
+        if num_tokens_across_dp is None:
+            assert ubatch_slices is None
+            assert num_tokens is not None
+            _, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
+                parallel_config=vllm_config.parallel_config,
+                allow_microbatching=False,
+                allow_dp_padding=False,
+            )
+        assert num_tokens_across_dp is not None
         dp_metadata = DPMetadata.make(
             vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 7a943909a8ba9..1bb6a6f4d05f7 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -7,7 +7,7 @@ import torch
 import torch.distributed as dist

 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
@@ -37,6 +37,7 @@ def _get_device_and_group(parallel_config: ParallelConfig):

 def _run_ar(
     should_ubatch: bool,
+    should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
     parallel_config: ParallelConfig,
@@ -44,10 +45,11 @@ def _run_ar(
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
+    tensor[3][dp_rank] = 1 if should_dp_pad else 0
     dist.all_reduce(tensor, group=group)
     return tensor

@@ -72,10 +74,26 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
     return should_ubatch


+def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
+    num_tokens_across_dp = tensor[1, :]
+    if should_dp_pad:
+        # If DP padding is enabled, ensure that each rank is processing the same number
+        # of tokens
+        max_num_tokens = int(num_tokens_across_dp.max().item())
+        return torch.tensor(
+            [max_num_tokens] * len(num_tokens_across_dp),
+            device="cpu",
+            dtype=torch.int32,
+        )
+    else:
+        return num_tokens_across_dp.cpu()
+
+
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
+    should_attempt_dp_padding: bool,
     parallel_config: ParallelConfig,
 ) -> tuple[bool, Optional[torch.Tensor]]:
     """
@@ -83,57 +101,88 @@ def _synchronize_dp_ranks(
     run with microbatching or none of them do.

     2. Determines the total number of tokens that each rank will run.
-    All ranks will be padded out so that the run with the same number
-    of tokens
+    When running microbatched or if should_attempt_dp_padding is True, all
+    ranks will be padded out so that they run with the same number of tokens

     Returns: tuple[
         should_ubatch: Are all DP ranks going to microbatch
         num_tokens_after_padding: A tensor containing the total number of
-            tokens per-microbatch for each DP rank including padding.
+            tokens per-microbatch for each DP rank including any DP padding.
     ]
     """
     assert num_tokens_padded >= num_tokens_unpadded

-    # First we coordinate between the DP ranks via an All Reduce
+    # Coordinate between the DP ranks via an All Reduce
     # to determine the total number of tokens that each rank
     # will run and if we are using ubatching or not.
     tensor = _run_ar(
         should_ubatch=should_attempt_ubatching,
+        should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
         parallel_config=parallel_config,
     )

-    # Ensure that each rank is processing the same nuber of tokens
-    num_tokens_across_dp = tensor[1, :]
-    max_num_tokens = int(num_tokens_across_dp.max().item())
-    num_tokens_after_padding = torch.tensor(
-        [max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32
-    )
+    should_dp_pad = bool(torch.all(tensor[3] == 1).item())
+    # DP ranks should all have the same value for should_attempt_dp_padding.
+    assert should_attempt_dp_padding == should_dp_pad
+
+    # Check conditions for microbatching
     should_ubatch = _post_process_ubatch(tensor)
+    if should_ubatch and not should_dp_pad:
+        if is_global_first_rank():
+            logger.debug(
+                "Microbatching has been triggered and requires DP padding. "
+                "Enabling DP padding even though it has been explicitly "
+                "disabled."
+            )
+        should_dp_pad = True
+
+    # Pad all DP ranks up to the maximum token count across ranks if
+    # should_dp_pad is True
+    num_tokens_after_padding = _post_process_dp_padding(
+        tensor,
+        should_dp_pad,
+    )
+
     return should_ubatch, num_tokens_after_padding


 def coordinate_batch_across_dp(
-    num_scheduled_tokens_per_request: np.ndarray,
     num_tokens_unpadded: int,
-    num_tokens_padded: int,
-    parallel_config: ParallelConfig,
     allow_microbatching: bool,
-    uniform_decode: bool,
+    allow_dp_padding: bool,
+    parallel_config: ParallelConfig,
+    num_tokens_padded: Optional[int] = None,
+    uniform_decode: Optional[bool] = None,
+    num_scheduled_tokens_per_request: Optional[np.ndarray] = None,
 ) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.

+    Args:
+        num_tokens_unpadded: Number of tokens without accounting for padding
+        allow_microbatching: If microbatching should be attempted
+        allow_dp_padding: If all DP ranks should be padded up to the same value
+        parallel_config: The parallel config
+        num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
+            TP, etc)
+        uniform_decode: Only used if allow_microbatching is True. True if the batch
+            only contains single token decodes
+        num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
+            number of tokens per request.
+
     Returns: tuple[
         ubatch_slices: if this is set then all DP ranks have agreed to
             microbatch
         num_tokens_after_padding: A tensor containing the total number of
-            tokens per-microbatch for each DP rank including padding.
+            tokens per-microbatch for each DP rank including padding. Will be
+            padded up to the max value across all DP ranks when allow_dp_padding
+            is True.
     ]
     """

@@ -141,21 +190,25 @@ def coordinate_batch_across_dp(
         # Early exit.
         return None, None

-    # Check preconditions for microbatching
-    should_attempt_ubatching = check_ubatch_thresholds(
-        parallel_config,
-        num_tokens_unpadded,
-        uniform_decode=uniform_decode,
-    )
+    # If the caller has explicitly enabled microbatching.
+    should_attempt_ubatching = False
+    if allow_microbatching:
+        # Check preconditions for microbatching
+        assert uniform_decode is not None
+        should_attempt_ubatching = check_ubatch_thresholds(
+            parallel_config,
+            num_tokens_unpadded,
+            uniform_decode=uniform_decode,
+        )

-    # If the caller has explicitly disabled microbatching.
-    if not allow_microbatching:
-        should_attempt_ubatching = False
+    if num_tokens_padded is None:
+        num_tokens_padded = num_tokens_unpadded

     (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
         num_tokens_unpadded,
         num_tokens_padded,
         should_attempt_ubatching,
+        allow_dp_padding,
         parallel_config,
     )

@@ -170,6 +223,7 @@ def coordinate_batch_across_dp(
         assert num_tokens_after_padding is not None
         token_split_point = int(num_tokens_after_padding[0].item()) // 2

+        assert num_scheduled_tokens_per_request is not None
         ubatch_slices = create_ubatch_slices(
             num_scheduled_tokens_per_request, token_split_point
         )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0772cbeeff21d..f5b73e46a239c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1178,13 +1178,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         uniform_decode = (
             max_num_scheduled_tokens == self.uniform_decode_query_len
         ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
+
+        # Disable DP padding when running eager to avoid excessive padding when
+        # running prefills. This lets us set enforce_eager on the prefiller in
+        # a P/D setup and still use CUDA graphs (enabled by this padding) on the
+        # decoder.
+        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+
         ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_scheduled_tokens,
-            num_tokens_unpadded,
-            num_tokens_padded,
-            self.parallel_config,
-            True,
-            uniform_decode,
+            num_tokens_unpadded=num_tokens_unpadded,
+            parallel_config=self.parallel_config,
+            allow_microbatching=True,
+            allow_dp_padding=allow_dp_padding,
+            num_tokens_padded=num_tokens_padded,
+            uniform_decode=uniform_decode,
+            num_scheduled_tokens_per_request=num_scheduled_tokens,
         )

         self.seq_lens.np[:num_reqs] = (
@@ -2436,12 +2444,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             use_cascade_attn,
         ) = self._prepare_inputs(scheduler_output)

+        dp_rank = self.parallel_config.data_parallel_rank
         if ubatch_slices:
             assert num_tokens_across_dp is not None
-            num_input_tokens = int(num_tokens_across_dp[0].item())
+            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
             self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
         elif num_tokens_across_dp is not None:
-            num_input_tokens = int(num_tokens_across_dp[0].item())
+            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
         else:
             num_input_tokens = self._get_num_input_tokens(
                 scheduler_output.total_num_scheduled_tokens
@@ -3256,19 +3265,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())

+        # Disable DP padding when running eager
+        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
         ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_scheduled_tokens,
-            total_num_scheduled_tokens,
-            total_num_scheduled_tokens,
-            self.vllm_config.parallel_config,
-            allow_microbatching,
-            uniform_decode,
+            num_tokens_unpadded=total_num_scheduled_tokens,
+            parallel_config=self.vllm_config.parallel_config,
+            allow_microbatching=allow_microbatching,
+            allow_dp_padding=allow_dp_padding,
+            num_tokens_padded=total_num_scheduled_tokens,
+            uniform_decode=uniform_decode,
+            num_scheduled_tokens_per_request=num_scheduled_tokens,
         )

         num_tokens_after_padding = num_tokens
         if num_tokens_across_dp is not None:
-            num_tokens_after_padding = int(num_tokens_across_dp[0])
+            dp_rank = self.parallel_config.data_parallel_rank
+            num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])

         attn_metadata: Optional[PerLayerAttnMetadata] = None
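
Reviewer note: below is a minimal, single-process sketch of the coordination protocol that the reworked `_run_ar` / `_post_process_ubatch` / `_post_process_dp_padding` path implements. The `coordinate` helper and the example token counts are illustrative only and not part of this PR; the all-reduce is replaced by pre-filled per-rank columns, and the token-threshold checks that also gate microbatching are omitted.

```python
import torch


def coordinate(orig_tokens, padded_tokens, want_ubatch, want_dp_pad):
    dp_size = len(orig_tokens)
    # Row layout mirrors _run_ar: 0 = unpadded tokens, 1 = locally padded tokens,
    # 2 = microbatch request flag, 3 = DP padding request flag. In the real code
    # each rank fills only its own column and dist.all_reduce combines them.
    tensor = torch.zeros(4, dp_size, dtype=torch.int32)
    tensor[0] = torch.tensor(orig_tokens, dtype=torch.int32)
    tensor[1] = torch.tensor(padded_tokens, dtype=torch.int32)
    tensor[2] = torch.tensor(want_ubatch, dtype=torch.int32)
    tensor[3] = torch.tensor(want_dp_pad, dtype=torch.int32)

    # Agreement checks: microbatch / DP-pad only if every rank asked for it.
    should_ubatch = bool(torch.all(tensor[2] == 1).item())
    should_dp_pad = bool(torch.all(tensor[3] == 1).item())

    # Microbatching needs equal token counts on every rank, so it forces DP
    # padding back on (the PR logs a debug message in this case).
    if should_ubatch and not should_dp_pad:
        should_dp_pad = True

    num_tokens_across_dp = tensor[1, :]
    if should_dp_pad:
        # Pad every rank up to the max padded count, as _post_process_dp_padding does.
        max_tokens = int(num_tokens_across_dp.max().item())
        num_tokens_across_dp = torch.full((dp_size,), max_tokens, dtype=torch.int32)
    return should_ubatch, num_tokens_across_dp


# Two DP ranks running eager prefills (DP padding disabled): each rank keeps
# its own token count.
print(coordinate([96, 512], [96, 512], [0, 0], [0, 0]))
# Same batch with DP padding allowed: both ranks are padded to 512 tokens.
print(coordinate([96, 512], [96, 512], [0, 0], [1, 1]))
```

The first call shows why `num_tokens_across_dp` can now differ per rank when padding is disabled, which is what the new `dp_rank` indexing in `execute_model` and `_dummy_run` accounts for.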