[feature] extend DBO to XBO (#30120)

Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
Co-authored-by: root <root@hk01dgx028.cm.cluster>
jiangkuaixue123 2025-12-16 13:04:01 +08:00 committed by GitHub
parent c881db364e
commit b9ff4f2a8d
10 changed files with 133 additions and 73 deletions

View File

@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches(
num_tokens,
batch_spec.batch_size,
split_point=split_point,
num_ubatches=2,
)
assert ubatch_slices is not None and len(ubatch_slices) == 2

View File

@ -156,6 +156,8 @@ class ParallelConfig:
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
"""Number of ubatch size."""
dbo_decode_token_threshold: int = 32
"""The threshold for dual batch overlap for batches only containing decodes.
@ -325,6 +327,14 @@ class ParallelConfig:
including data parallelism."""
return self.world_size * self.data_parallel_size
@property
def use_ubatching(self) -> bool:
return self.enable_dbo or self.ubatch_size > 1
@property
def num_ubatches(self) -> int:
return 2 if self.enable_dbo else self.ubatch_size
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple

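For reference, a minimal standalone sketch of how the two knobs interact (field and property names mirror the diff; the rest of ParallelConfig is omitted). enable_dbo keeps the legacy two-ubatch behaviour, while ubatch_size > 1 requests an arbitrary ubatch count:

from dataclasses import dataclass

@dataclass
class _UbatchKnobs:
    # Simplified stand-in for the ParallelConfig fields touched by this diff.
    enable_dbo: bool = False   # legacy dual batch overlap (DBO) flag
    ubatch_size: int = 0       # number of micro-batches; > 1 enables XBO

    @property
    def use_ubatching(self) -> bool:
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        return 2 if self.enable_dbo else self.ubatch_size

assert _UbatchKnobs(enable_dbo=True).num_ubatches == 2   # DBO: always two ubatches
assert _UbatchKnobs(ubatch_size=4).num_ubatches == 4     # XBO: caller-chosen count
assert not _UbatchKnobs().use_ubatching                  # defaults: disabled

Note that enable_dbo takes precedence: with both flags set, num_ubatches stays 2, exactly as the property above is written.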
View File

@ -870,9 +870,12 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
if self.parallel_config.enable_dbo:
if self.parallel_config.use_ubatching:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
assert a2a_backend in [
"deepep_low_latency",
"deepep_high_throughput",
], (
"Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix use --all2all-backend=deepep_low_latency or "

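Because the gate now keys off use_ubatching rather than enable_dbo, plain --ubatch-size runs hit the same backend restriction as DBO. A hedged sketch of the check as a standalone helper (the constant and function names are illustrative, not taken from the diff):

_SUPPORTED_UBATCH_A2A_BACKENDS = ("deepep_low_latency", "deepep_high_throughput")

def _check_ubatch_a2a_backend(use_ubatching: bool, a2a_backend: str) -> None:
    # Mirrors the assert above: micro-batching only works with the DeepEP backends.
    if use_ubatching and a2a_backend not in _SUPPORTED_UBATCH_A2A_BACKENDS:
        raise ValueError(
            f"Microbatching only supports {_SUPPORTED_UBATCH_A2A_BACKENDS}; got "
            f"{a2a_backend!r}. Use --all2all-backend=deepep_low_latency."
        )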
View File

@ -408,6 +408,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization: bool = (
@ -841,6 +842,10 @@ class EngineArgs:
"--all2all-backend", **parallel_kwargs["all2all_backend"]
)
parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
@ -1557,6 +1562,7 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,

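A self-contained sketch of the CLI wiring pattern used here, with argparse only (the vLLM parallel_kwargs plumbing is omitted, and the type/default shown are assumptions based on the ParallelConfig field):

import argparse

DEFAULT_UBATCH_SIZE = 0  # stand-in for ParallelConfig.ubatch_size

parser = argparse.ArgumentParser()
parallel_group = parser.add_argument_group("Parallel")
parallel_group.add_argument("--enable-dbo", action="store_true",
                            help="Enable dual batch overlap (two ubatches).")
parallel_group.add_argument("--ubatch-size", type=int, default=DEFAULT_UBATCH_SIZE,
                            help="Number of micro-batches; > 1 enables XBO.")

ns = parser.parse_args(["--ubatch-size", "4"])
assert ns.ubatch_size == 4 and not ns.enable_dbo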
View File

@ -201,10 +201,11 @@ def _make_metadata_with_slice(
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the "middle" request has tokens in both ubatches, we have to split it.
# If ubatch_slice is the first ubatch then we will be splitting the last
# request. If it's the second microbatch, then we will be splitting the
# first request
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
@ -225,7 +226,10 @@ def _make_metadata_with_slice(
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped

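A small numeric sketch of the two split conditions and the corrected tokens_skipped computation (NumPy only; start_locs plays the role of the original query_start_loc_cpu, while the request/token layout and the searchsorted lookup of first_req/last_req are made up for illustration):

import numpy as np

start_locs = np.array([0, 4, 9, 12])   # cumulative token starts for 3 requests
token_slice = slice(5, 10)             # tokens assigned to a "middle" ubatch

first_tok, last_tok = token_slice.start, token_slice.stop - 1
first_req = int(np.searchsorted(start_locs, first_tok, side="right") - 1)  # 1
last_req = int(np.searchsorted(start_locs, last_tok, side="right") - 1)    # 2

# Request 1 already started in the previous ubatch (its first token is 4 < 5).
splits_first_request = first_tok > start_locs[first_req]                   # True
# Request 2 continues into the next ubatch (its last token is 11 > 9).
splits_last_request = last_tok < start_locs[last_req + 1] - 1              # True

# tokens_skipped must come from the unmodified start_locs, as the NOTE explains.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop               # 12 - 10 = 2
assert splits_first_request and splits_last_request and tokens_skipped == 2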
View File

@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
check_ubatch_thresholds,
is_second_ubatch_empty,
is_last_ubatch_empty,
)
logger = init_logger(__name__)
@ -56,7 +56,7 @@ def _run_ar(
return tensor
def _post_process_ubatch(tensor: torch.Tensor) -> bool:
def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
# there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
)
@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
assert should_attempt_dp_padding == should_dp_pad
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor)
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad:
logger.debug_once(

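A hedged sketch of the abort decision with concrete numbers, assuming the tensor layout described above (row 0 holds the unpadded token count per DP rank, row 1 the padded count); is_last_ubatch_empty is re-implemented inline so the snippet runs standalone:

import torch

def is_last_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int,
                         num_ubatches: int) -> bool:
    # Same formula as the diff: the last ubatch is empty when the first
    # (num_ubatches - 1) equal-sized padded slices already cover every real token.
    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens

# Two DP ranks, 4 ubatches: rank 0 has 9 real tokens, rank 1 has 40; both pad to 48.
tensor = torch.tensor([[9, 40],    # orig_num_tokens per rank
                       [48, 48]])  # padded_num_tokens per rank
orig_min = int(tensor[0, :].min().item())    # 9
padded_max = int(tensor[1, :].max().item())  # 48
# 48 // 4 * 3 = 36 >= 9, so rank 0 would get empty trailing ubatches -> abort ubatching.
assert is_last_ubatch_empty(orig_min, padded_max, num_ubatches=4)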
View File

@ -2987,7 +2987,7 @@ class GPUModelRunner(
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
# Pre-compute cascade attention prefix lengths
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np,
@ -3028,6 +3028,13 @@ class GPUModelRunner(
num_scheduled_tokens_np,
num_tokens_padded,
num_reqs_padded,
self.parallel_config.num_ubatches,
)
logger.debug(
"ubatch_slices: %s, ubatch_slices_padded: %s",
ubatch_slices,
ubatch_slices_padded,
)
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
@ -3710,11 +3717,14 @@ class GPUModelRunner(
# wrap the model with full cudagraph wrapper if needed.
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
if (
cudagraph_mode.has_full_cudagraphs()
and not self.parallel_config.use_ubatching
):
self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
elif self.parallel_config.enable_dbo:
elif self.parallel_config.use_ubatching:
if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
@ -4095,7 +4105,16 @@ class GPUModelRunner(
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
should_ubatch,
num_scheduled_tokens,
num_tokens_padded,
num_reqs_padded,
self.vllm_config.parallel_config.num_ubatches,
)
logger.debug(
"ubatch_slices: %s, ubatch_slices_padded: %s",
ubatch_slices,
ubatch_slices_padded,
)
attn_metadata: PerLayerAttnMetadata | None = None
@ -4644,7 +4663,7 @@ class GPUModelRunner(
# is above the threshold. Otherwise we just capture a non-ubatched
# version of the graph
allow_microbatching = (
self.parallel_config.enable_dbo
self.parallel_config.use_ubatching
and cudagraph_runtime_mode == CUDAGraphMode.FULL
and uniform_decode
and check_ubatch_thresholds(
@ -4779,8 +4798,8 @@ class GPUModelRunner(
if kv_cache_group_id < len(kernel_block_sizes)
else None,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
if not self.parallel_config.use_ubatching
else self.parallel_config.num_ubatches,
)
# Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,

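The wrapper choice and the builder count now branch on use_ubatching / num_ubatches rather than the hardcoded DBO pair. A condensed sketch of the decision logic (the wrapper classes are reduced to strings, and the non-full-cudagraph ubatching branch, which this hunk does not show, is collapsed into "other"):

def pick_wrapper(has_full_cudagraphs: bool, use_ubatching: bool) -> str:
    # Condensed from the branch above.
    if has_full_cudagraphs and not use_ubatching:
        return "CUDAGraphWrapper(FULL)"
    if use_ubatching and has_full_cudagraphs:
        return "UBatchWrapper(FULL)"
    return "other"

def num_metadata_builders(use_ubatching: bool, num_ubatches: int) -> int:
    # One attention metadata builder per ubatch when micro-batching, one otherwise.
    return num_ubatches if use_ubatching else 1

assert pick_wrapper(True, False) == "CUDAGraphWrapper(FULL)"
assert pick_wrapper(True, True) == "UBatchWrapper(FULL)"
assert num_metadata_builders(True, 4) == 4 and num_metadata_builders(False, 0) == 1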
View File

@ -103,8 +103,10 @@ class UBatchWrapper:
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device)
# Two ubatch threads plus the main thread
self.ready_barrier = threading.Barrier(3)
# One worker thread per ubatch plus the main thread
self.ready_barrier = threading.Barrier(
self.vllm_config.parallel_config.num_ubatches + 1
)
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
@ -309,7 +311,7 @@ class UBatchWrapper:
create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config,
dp_metadata=dp_metadata,
dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
)
@ -417,18 +419,19 @@ class UBatchWrapper:
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
num_tokens_per_ubatch = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
)
ubatch_dp_metadata = []
for ubatch_slice in ubatch_slices:
dp_size = self.vllm_config.parallel_config.data_parallel_size
ubatch_num_tokens_across_dp = torch.tensor(
[num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
[ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
)
ubatch_dp_metadata = DPMetadata.make(
ubatch_dp_metadata.append(
DPMetadata.make(
self.vllm_config.parallel_config,
num_tokens_per_ubatch,
ubatch_slice.num_tokens,
ubatch_num_tokens_across_dp,
)
)
if (
num_tokens not in self.cudagraphs
@ -464,7 +467,7 @@ class UBatchWrapper:
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)

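A minimal runnable sketch of the resized ready barrier: one worker thread per ubatch plus the main thread all meet at the barrier before the forward pass proceeds (the worker body is a placeholder; the per-ubatch DPMetadata list built above follows the same one-entry-per-ubatch pattern):

import threading

num_ubatches = 4
ready_barrier = threading.Barrier(num_ubatches + 1)  # N ubatch threads + main thread
ready = []

def ubatch_worker(idx: int) -> None:
    ready.append(idx)        # stand-in for per-ubatch setup work
    ready_barrier.wait()     # signal readiness to the main thread

threads = [threading.Thread(target=ubatch_worker, args=(i,)) for i in range(num_ubatches)]
for t in threads:
    t.start()
ready_barrier.wait()         # main thread blocks until every ubatch thread arrives
for t in threads:
    t.join()
assert sorted(ready) == list(range(num_ubatches))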
View File

@ -27,14 +27,16 @@ class UBatchSlice:
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
return (padded_num_tokens // 2) >= orig_num_tokens
def is_last_ubatch_empty(
orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens
def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
if not config.enable_dbo:
if not config.use_ubatching:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
@ -42,21 +44,17 @@ def check_ubatch_thresholds(
return num_tokens >= config.dbo_prefill_token_threshold
# This just pads the second ubatch slice out to the total number of tokens
# This pads the last ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
# TODO(lucas): handle empty second ubatch
padded_second_request_slice = slice(
ubatch_slices[1].request_slice.start, num_reqs_padded
)
padded_second_token_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
return [
ubatch_slices[0],
UBatchSlice(padded_second_request_slice, padded_second_token_slice),
last_slice = ubatch_slices[-1]
padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
return ubatch_slices[:-1] + [
UBatchSlice(padded_last_request_slice, padded_last_token_slice)
]
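
A worked sketch of the generalized padding: only the last slice is stretched to absorb CUDA graph / DP padding, while the earlier slices stay untouched (UBatchSlice is re-declared in simplified form so the snippet runs standalone):

from dataclasses import dataclass

@dataclass(frozen=True)
class UBatchSlice:  # simplified stand-in for the real dataclass
    request_slice: slice
    token_slice: slice

def pad_out_ubatch_slices(ubatch_slices, num_total_tokens, num_reqs_padded):
    last = ubatch_slices[-1]
    return ubatch_slices[:-1] + [
        UBatchSlice(slice(last.request_slice.start, num_reqs_padded),
                    slice(last.token_slice.start, num_total_tokens))
    ]

# Three ubatches over 12 real tokens / 3 requests, padded to 16 tokens / 4 requests.
slices = [UBatchSlice(slice(0, 1), slice(0, 4)),
          UBatchSlice(slice(1, 2), slice(4, 8)),
          UBatchSlice(slice(2, 3), slice(8, 12))]
padded = pad_out_ubatch_slices(slices, num_total_tokens=16, num_reqs_padded=4)
assert padded[:2] == slices[:2]                              # earlier slices untouched
assert padded[-1] == UBatchSlice(slice(2, 4), slice(8, 16))  # last slice absorbs padding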
@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
split_point: int | None = None,
num_ubatches: int,
split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch:
return None, None
if split_point is None:
split_point = int(num_tokens_padded) // 2
split_point = int(num_tokens_padded) // num_ubatches
token_split_points = [split_point * i for i in range(1, num_ubatches)]
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
ubatch_slices = []
start_token = 0
# Add the end point to the split points to make iteration easier
all_points = token_split_points + [cu_num_tokens[-1]]
for end_token in all_points:
token_slice = slice(start_token, end_token)
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Ubatch includes requests whose tokens overlap [start_token, end_token)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
# Start at the request that contains the start_token
# or the request starting exactly at start_token (if on boundary)
req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
ubatch_slices = [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
]
# Stop at the request that starts at or after end_token
req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
req_slice = slice(req_start, req_stop)
ubatch_slices.append(UBatchSlice(req_slice, token_slice))
start_token = end_token
ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded

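A worked, self-contained version of the new slicing loop for num_ubatches=3, returning plain (request_slice, token_slice) pairs instead of UBatchSlice objects; should_ubatch handling, caller-supplied split points and the padded variant are omitted:

import numpy as np

def create_ubatch_slices(num_scheduled_tokens, num_tokens_padded, num_ubatches):
    # Follows the loop above: equal token split points, then searchsorted to find
    # which requests overlap each [start_token, end_token) window.
    split = int(num_tokens_padded) // num_ubatches
    cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu[1:])
    points = [split * i for i in range(1, num_ubatches)] + [int(cu[-1])]
    slices, start = [], 0
    for end in points:
        req_start = int(np.searchsorted(cu, start, side="right") - 1)
        req_stop = int(np.searchsorted(cu, end, side="left"))
        slices.append((slice(req_start, req_stop), slice(start, end)))
        start = end
    return slices

# Three requests with 7, 3 and 2 tokens; 12 tokens split into 3 ubatches of 4.
out = create_ubatch_slices(np.array([7, 3, 2]), num_tokens_padded=12, num_ubatches=3)
assert out == [(slice(0, 1), slice(0, 4)),    # ubatch 0: first part of request 0
               (slice(0, 2), slice(4, 8)),    # ubatch 1: rest of req 0 + start of req 1
               (slice(1, 3), slice(8, 12))]   # ubatch 2: rest of req 1 + all of req 2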
View File

@ -7,10 +7,15 @@ import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__)
_THREAD_ID_TO_CONTEXT: dict = {}
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
# The number of microbatches defaults to 2; make_ubatch_contexts overrides it at runtime.
_NUM_UBATCHES: int = 2
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = []
class UBatchContext:
@ -48,6 +53,7 @@ class UBatchContext:
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
_CURRENT_CONTEXTS[self.id] = self
# _NUM_UBATCHES is set in make_ubatch_contexts
self.ready_barrier.wait()
self.cpu_wait_event.wait()
@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function(
def dbo_register_recv_hook(recv_hook):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
next_ctx.recv_hook = recv_hook
@ -202,7 +208,14 @@ def make_ubatch_contexts(
ready_barrier: threading.Barrier,
schedule: str = "default",
) -> list[UBatchContext]:
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
global _NUM_UBATCHES, _CURRENT_CONTEXTS
assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
_NUM_UBATCHES = num_micro_batches
# Ensure the global context list is large enough
if len(_CURRENT_CONTEXTS) < num_micro_batches:
_CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
"""
Create a context manager for micro-batching synchronization.
"""
@ -210,8 +223,6 @@ def make_ubatch_contexts(
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2
ctxs = []
for i in range(num_micro_batches):
ctx = UBatchContext(
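
The recv-hook handoff now walks the context ring modulo the runtime ubatch count instead of a hardcoded 2. A minimal sketch of the neighbor lookup used by dbo_register_recv_hook (the contexts are plain placeholders here):

num_ubatches = 3
contexts = [f"ctx{i}" for i in range(num_ubatches)]  # stand-ins for UBatchContext objects

def next_context(ctx_idx: int) -> str:
    # Same indexing as above: hand the hook to the next ubatch's context,
    # wrapping from the last ubatch back to the first.
    return contexts[(ctx_idx + 1) % num_ubatches]

assert next_context(0) == "ctx1"
assert next_context(num_ubatches - 1) == "ctx0"  # wrap-around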