From b9ff4f2a8dffc84b2ce226e7e98c33756caf098f Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Tue, 16 Dec 2025 13:04:01 +0800 Subject: [PATCH] [feature] extend DBO to XBO (#30120) Signed-off-by: jiangkuaixue123 Co-authored-by: root --- .../v1/attention/test_attention_splitting.py | 1 + vllm/config/parallel.py | 10 +++ vllm/config/vllm.py | 7 +- vllm/engine/arg_utils.py | 6 ++ vllm/v1/attention/backends/utils.py | 14 ++-- vllm/v1/worker/dp_utils.py | 8 +-- vllm/v1/worker/gpu_model_runner.py | 33 +++++++-- vllm/v1/worker/gpu_ubatch_wrapper.py | 35 ++++----- vllm/v1/worker/ubatch_utils.py | 71 ++++++++++--------- vllm/v1/worker/ubatching.py | 21 ++++-- 10 files changed, 133 insertions(+), 73 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index f08e2f480e30f..734819fcdca83 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches( num_tokens, batch_spec.batch_size, split_point=split_point, + num_ubatches=2, ) assert ubatch_slices is not None and len(ubatch_slices) == 2 diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 1f9dd38ac9114..3fe066ec32505 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -156,6 +156,8 @@ class ParallelConfig: enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" + ubatch_size: int = 0 + """Number of ubatch size.""" dbo_decode_token_threshold: int = 32 """The threshold for dual batch overlap for batches only containing decodes. @@ -325,6 +327,14 @@ class ParallelConfig: including data parallelism.""" return self.world_size * self.data_parallel_size + @property + def use_ubatching(self) -> bool: + return self.enable_dbo or self.ubatch_size > 1 + + @property + def num_ubatches(self) -> int: + return 2 if self.enable_dbo else self.ubatch_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ace5adc109d86..0439dc52e7e6f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -870,9 +870,12 @@ class VllmConfig: f"cudagraph_mode={self.compilation_config.cudagraph_mode}" ) - if self.parallel_config.enable_dbo: + if self.parallel_config.use_ubatching: a2a_backend = self.parallel_config.all2all_backend - assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + assert a2a_backend in [ + "deepep_low_latency", + "deepep_high_throughput", + ], ( "Microbatching currently only supports the deepep_low_latency and " f"deepep_high_throughput all2all backend. {a2a_backend} is not " "supported. 
To fix use --all2all-backend=deepep_low_latency or " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3862aa9222446..ca19e468914c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -408,6 +408,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo + ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold disable_nccl_for_dp_synchronization: bool = ( @@ -841,6 +842,10 @@ class EngineArgs: "--all2all-backend", **parallel_kwargs["all2all_backend"] ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--ubatch-size", + **parallel_kwargs["ubatch_size"], + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -1557,6 +1562,7 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, enable_dbo=self.enable_dbo, + ubatch_size=self.ubatch_size, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index da43d87038234..1cbe929fc57a8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -201,10 +201,11 @@ def _make_metadata_with_slice( ) # NOTE: last token can be outside of the last request if we have CG padding. - # If the "middle" request has tokens in both ubatches, we have to split it. - # If ubatch_slice is the first ubatch then we will be splitting the last - # request. If it's the second microbatch, then we will be splitting the - # first request + # If the request is split across ubatches, we have to adjust the metadata. + # splits_first_request: The first request in this slice is the continuation of + # a request that started in a previous slice. + # splits_last_request: The last request in this slice continues into the + # next slice. splits_first_request = first_tok > start_locs[first_req] splits_last_request = last_tok < start_locs[last_req + 1] - 1 @@ -225,7 +226,10 @@ def _make_metadata_with_slice( seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] if splits_last_request: - tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate + # the tokens skipped because query_start_loc_cpu might have been modified + # if splits_first_request is True. 
+ tokens_skipped = start_locs[last_req + 1] - token_slice.stop query_start_loc[-1] -= tokens_skipped query_start_loc_cpu[-1] -= tokens_skipped diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 1b9646e1980a8..82de0cba9194b 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.logger import init_logger from vllm.v1.worker.ubatch_utils import ( check_ubatch_thresholds, - is_second_ubatch_empty, + is_last_ubatch_empty, ) logger = init_logger(__name__) @@ -56,7 +56,7 @@ def _run_ar( return tensor -def _post_process_ubatch(tensor: torch.Tensor) -> bool: +def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool: orig_num_tokens_tensor = tensor[0, :] padded_num_tokens_tensor = tensor[1, :] @@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool: # there are no "empty" second ubatches orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) - if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): + if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches): logger.debug( "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens ) @@ -146,7 +146,7 @@ def _synchronize_dp_ranks( assert should_attempt_dp_padding == should_dp_pad # Check conditions for microbatching - should_ubatch = _post_process_ubatch(tensor) + should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches) if should_ubatch and not should_dp_pad: logger.debug_once( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 978224faae65e..1aa2ec6bb655c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2987,7 +2987,7 @@ class GPUModelRunner( cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) - if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: # Pre-compute cascade attention prefix lengths cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( num_scheduled_tokens_np, @@ -3028,6 +3028,13 @@ class GPUModelRunner( num_scheduled_tokens_np, num_tokens_padded, num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, ) pad_attn = cudagraph_mode == CUDAGraphMode.FULL @@ -3710,11 +3717,14 @@ class GPUModelRunner( # wrap the model with full cudagraph wrapper if needed. 
cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching + ): self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - elif self.parallel_config.enable_dbo: + elif self.parallel_config.use_ubatching: if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device @@ -4095,7 +4105,16 @@ class GPUModelRunner( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, ) attn_metadata: PerLayerAttnMetadata | None = None @@ -4644,7 +4663,7 @@ class GPUModelRunner( # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph allow_microbatching = ( - self.parallel_config.enable_dbo + self.parallel_config.use_ubatching and cudagraph_runtime_mode == CUDAGraphMode.FULL and uniform_decode and check_ubatch_thresholds( @@ -4779,8 +4798,8 @@ class GPUModelRunner( if kv_cache_group_id < len(kernel_block_sizes) else None, num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, + if not self.parallel_config.use_ubatching + else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) # Note (tdoublep): do this *after* constructing builders, diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 2ce2b64512560..af09129e67b1e 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -103,8 +103,10 @@ class UBatchWrapper: self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.comm_stream = torch.cuda.Stream(device=device) - # Two ubatch threads plus the main thread - self.ready_barrier = threading.Barrier(3) + # Ubatch threads plus the main thread + self.ready_barrier = threading.Barrier( + self.vllm_config.parallel_config.num_ubatches + 1 + ) self.cudagraphs: dict[int, CUDAGraphMetaData] = {} @@ -309,7 +311,7 @@ class UBatchWrapper: create_forward_context( attn_metadata[i] if attn_metadata is not None else None, self.vllm_config, - dp_metadata=dp_metadata, + dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, ) @@ -417,18 +419,19 @@ class UBatchWrapper: # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None - num_tokens_per_ubatch = ( - ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start - ) - dp_size = self.vllm_config.parallel_config.data_parallel_size - ubatch_num_tokens_across_dp = torch.tensor( - [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 - ) - ubatch_dp_metadata = DPMetadata.make( - self.vllm_config.parallel_config, - num_tokens_per_ubatch, - ubatch_num_tokens_across_dp, - ) + ubatch_dp_metadata = [] + for ubatch_slice in ubatch_slices: + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [ubatch_slice.num_tokens] * 
dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata.append( + DPMetadata.make( + self.vllm_config.parallel_config, + ubatch_slice.num_tokens, + ubatch_num_tokens_across_dp, + ) + ) if ( num_tokens not in self.cudagraphs @@ -464,7 +467,7 @@ class UBatchWrapper: intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, compute_stream=compute_stream, - dp_metadata=dp_metadata, + dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, ) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 44788476fc9c5..f6889173578d6 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -27,14 +27,16 @@ class UBatchSlice: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: - return (padded_num_tokens // 2) >= orig_num_tokens +def is_last_ubatch_empty( + orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int +) -> bool: + return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens def check_ubatch_thresholds( config: ParallelConfig, num_tokens: int, uniform_decode: bool ) -> bool: - if not config.enable_dbo: + if not config.use_ubatching: return False if uniform_decode: return num_tokens >= config.dbo_decode_token_threshold @@ -42,21 +44,17 @@ def check_ubatch_thresholds( return num_tokens >= config.dbo_prefill_token_threshold -# This just pads the second ubatch slice out to the total number of tokens +# This pads the last ubatch slice out to the total number of tokens # (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding. def _pad_out_ubatch_slices( ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int ) -> UBatchSlices: - # TODO(lucas): handle empty second ubatch - padded_second_request_slice = slice( - ubatch_slices[1].request_slice.start, num_reqs_padded - ) - padded_second_token_slice = slice( - ubatch_slices[1].token_slice.start, num_total_tokens - ) - return [ - ubatch_slices[0], - UBatchSlice(padded_second_request_slice, padded_second_token_slice), + last_slice = ubatch_slices[-1] + padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded) + padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens) + + return ubatch_slices[:-1] + [ + UBatchSlice(padded_last_request_slice, padded_last_token_slice) ] @@ -65,40 +63,45 @@ def maybe_create_ubatch_slices( num_scheduled_tokens: np.ndarray, num_tokens_padded: int, num_reqs_padded: int, - split_point: int | None = None, + num_ubatches: int, + split_point: list[int] | int | None = None, ) -> tuple[UBatchSlices | None, UBatchSlices | None]: if not should_ubatch: return None, None if split_point is None: - split_point = int(num_tokens_padded) // 2 + split_point = int(num_tokens_padded) // num_ubatches + + token_split_points = [split_point * i for i in range(1, num_ubatches)] # TODO(lucas): Refactor the gpu_model_runner.py so we can pass # in cu_num_tokens directly (i.e. 
query_start_loc) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) - first_ubatch_token_slice = slice(0, split_point) - second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + ubatch_slices = [] + start_token = 0 - # Determine request slices using exclusive stop semantics - # First ubatch includes requests whose tokens overlap [0, split_point) - first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left") - ) - first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + # Add the end point to the split points to make iteration easier + all_points = token_split_points + [cu_num_tokens[-1]] - # Second ubatch starts at the request that contains the split_point - # or the request starting exactly at split_point (if on boundary) - second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1 - ) - second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) + for end_token in all_points: + token_slice = slice(start_token, end_token) - ubatch_slices = [ - UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), - ] + # Determine request slices using exclusive stop semantics + # Ubatch includes requests whose tokens overlap [start_token, end_token) + + # Start at the request that contains the start_token + # or the request starting exactly at start_token (if on boundary) + req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1) + + # Stop at the request that starts at or after end_token + req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left")) + + req_slice = slice(req_start, req_stop) + ubatch_slices.append(UBatchSlice(req_slice, token_slice)) + + start_token = end_token ubatch_slices_padded = _pad_out_ubatch_slices( ubatch_slices, num_tokens_padded, num_reqs_padded diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index be8326e2fdbc1..e7a947f2ea8ca 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,10 +7,15 @@ import torch from vllm import forward_context from vllm.forward_context import ForwardContext +from vllm.logger import init_logger from vllm.utils.torch_utils import current_stream +logger = init_logger(__name__) + _THREAD_ID_TO_CONTEXT: dict = {} -_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] +# Here we hardcode the number of microbatches to 2 for default. 
+_NUM_UBATCHES: int = 2 +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [] class UBatchContext: @@ -48,6 +53,7 @@ class UBatchContext: global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id _CURRENT_CONTEXTS[self.id] = self + # _NUM_UBATCHES is set in make_ubatch_contexts self.ready_barrier.wait() self.cpu_wait_event.wait() @@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function( def dbo_register_recv_hook(recv_hook): if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] - next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES] next_ctx.recv_hook = recv_hook @@ -202,7 +208,14 @@ def make_ubatch_contexts( ready_barrier: threading.Barrier, schedule: str = "default", ) -> list[UBatchContext]: - assert num_micro_batches == 2, "only been tested with 2 micro-batches" + global _NUM_UBATCHES, _CURRENT_CONTEXTS + assert num_micro_batches > 1, "num_micro_batches must be greater than 1" + + _NUM_UBATCHES = num_micro_batches + # Ensure the global context list is large enough + if len(_CURRENT_CONTEXTS) < num_micro_batches: + _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS))) + """ Create a context manager for micro-batching synchronization. """ @@ -210,8 +223,6 @@ def make_ubatch_contexts( gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)] gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)] - assert len(forward_contexts) == 2 - ctxs = [] for i in range(num_micro_batches): ctx = UBatchContext(
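
The heart of the DBO-to-XBO change is in maybe_create_ubatch_slices (vllm/v1/worker/ubatch_utils.py above): the padded token count is cut into num_ubatches equal segments instead of a single midpoint split, and each segment's request range is recovered from the cumulative token counts with np.searchsorted. Below is a minimal standalone sketch of that mapping, assuming a simplified stand-in for UBatchSlice and leaving out the later padding step (_pad_out_ubatch_slices):

import numpy as np
from dataclasses import dataclass


@dataclass
class UBatchSlice:
    # Simplified stand-in for vllm.v1.worker.ubatch_utils.UBatchSlice.
    request_slice: slice
    token_slice: slice


def create_ubatch_slices(
    num_scheduled_tokens: np.ndarray, num_tokens_padded: int, num_ubatches: int
) -> list[UBatchSlice]:
    # Equal-sized token segments over the padded token count.
    split = num_tokens_padded // num_ubatches
    split_points = [split * i for i in range(1, num_ubatches)]

    # cu[i] is the index of the first token of request i (exclusive prefix sum).
    cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu[1:])

    slices = []
    start = 0
    for end in split_points + [int(cu[-1])]:
        # Request that contains `start` (or that starts exactly on the boundary).
        req_start = int(np.searchsorted(cu, start, side="right") - 1)
        # First request that starts at or after `end`.
        req_stop = int(np.searchsorted(cu, end, side="left"))
        slices.append(UBatchSlice(slice(req_start, req_stop), slice(start, end)))
        start = end
    return slices


# Three requests of 4 tokens each, two ubatches of 6 tokens: request 1
# straddles the boundary and shows up in both slices, which is exactly the
# case _make_metadata_with_slice repairs via splits_first_request /
# splits_last_request.
print(create_ubatch_slices(np.array([4, 4, 4]), 12, 2))

As in the patch, the last segment ends at the original token count; _pad_out_ubatch_slices then stretches only the final slice out to the padded request and token totals.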
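
Before any of that runs, _post_process_ubatch in dp_utils.py aborts microbatching when DP padding would leave the final ubatch with no real tokens. The generalized predicate is short enough to restate here with a few worked values (the numbers are illustrative only, not taken from a real trace):

def is_last_ubatch_empty(
    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
    # If the first (num_ubatches - 1) equal-sized slices of the padded batch
    # already cover every real token, the last slice would hold only padding.
    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens


# num_ubatches=2 reduces to the old is_second_ubatch_empty: padded // 2 >= orig.
assert is_last_ubatch_empty(10, 24, 2)       # slice size 12 covers all 10 real tokens
assert not is_last_ubatch_empty(13, 24, 2)   # token 12 spills into the second slice
assert is_last_ubatch_empty(20, 32, 4)       # 3 slices of 8 cover 24 >= 20 real tokens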
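
On the configuration side, --enable-dbo stays as the two-ubatch special case and the new --ubatch-size sits alongside it, reconciled by the use_ubatching and num_ubatches properties on ParallelConfig. A condensed sketch of just that selection logic (the class name here is hypothetical; the fields, defaults, and properties follow vllm/config/parallel.py):

from dataclasses import dataclass


@dataclass
class ParallelConfigSketch:
    enable_dbo: bool = False  # legacy dual batch overlap switch (always 2 ubatches)
    ubatch_size: int = 0      # generalized ubatch count; 0 or 1 means no ubatching

    @property
    def use_ubatching(self) -> bool:
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        # enable_dbo takes precedence and pins the count to 2.
        return 2 if self.enable_dbo else self.ubatch_size


assert ParallelConfigSketch(enable_dbo=True).num_ubatches == 2
assert ParallelConfigSketch(ubatch_size=4).num_ubatches == 4
assert not ParallelConfigSketch().use_ubatching

Either path still requires one of the DeepEP all2all backends, per the assertion extended in vllm/config/vllm.py.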