[feature] extend DBO to XBO (#30120)
Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
Co-authored-by: root <root@hk01dgx028.cm.cluster>
parent c881db364e
commit b9ff4f2a8d
@@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches(
         num_tokens,
         batch_spec.batch_size,
         split_point=split_point,
+        num_ubatches=2,
     )
     assert ubatch_slices is not None and len(ubatch_slices) == 2

@@ -156,6 +156,8 @@ class ParallelConfig:

     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""
+    ubatch_size: int = 0
+    """Number of micro-batches (ubatches) to split each batch into."""

     dbo_decode_token_threshold: int = 32
     """The threshold for dual batch overlap for batches only containing decodes.
@@ -325,6 +327,14 @@ class ParallelConfig:
         including data parallelism."""
         return self.world_size * self.data_parallel_size

+    @property
+    def use_ubatching(self) -> bool:
+        return self.enable_dbo or self.ubatch_size > 1
+
+    @property
+    def num_ubatches(self) -> int:
+        return 2 if self.enable_dbo else self.ubatch_size
+
     def get_next_dp_init_port(self) -> int:
         """
         We might need to initialize process groups in multiple
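For reference, a minimal standalone sketch (not the vLLM class itself) of how the two new properties interact with the legacy flag, matching the semantics shown in the hunk above; the _UbatchKnobs name is illustrative only.

from dataclasses import dataclass


@dataclass
class _UbatchKnobs:
    enable_dbo: bool = False  # legacy 2-way dual batch overlap flag
    ubatch_size: int = 0      # 0 = disabled, N > 1 = split into N ubatches

    @property
    def use_ubatching(self) -> bool:
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        return 2 if self.enable_dbo else self.ubatch_size


assert _UbatchKnobs(enable_dbo=True).num_ubatches == 2  # DBO stays 2-way
assert _UbatchKnobs(ubatch_size=4).num_ubatches == 4    # XBO: N-way overlap
assert not _UbatchKnobs().use_ubatching                 # both knobs off by default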
@@ -870,9 +870,12 @@ class VllmConfig:
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )

-        if self.parallel_config.enable_dbo:
+        if self.parallel_config.use_ubatching:
             a2a_backend = self.parallel_config.all2all_backend
-            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
+            assert a2a_backend in [
+                "deepep_low_latency",
+                "deepep_high_throughput",
+            ], (
                 "Microbatching currently only supports the deepep_low_latency and "
                 f"deepep_high_throughput all2all backend. {a2a_backend} is not "
                 "supported. To fix use --all2all-backend=deepep_low_latency or "
@@ -408,6 +408,7 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     all2all_backend: str | None = ParallelConfig.all2all_backend
     enable_dbo: bool = ParallelConfig.enable_dbo
+    ubatch_size: int = ParallelConfig.ubatch_size
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
     disable_nccl_for_dp_synchronization: bool = (
@@ -841,6 +842,10 @@ class EngineArgs:
             "--all2all-backend", **parallel_kwargs["all2all_backend"]
         )
         parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
+        parallel_group.add_argument(
+            "--ubatch-size",
+            **parallel_kwargs["ubatch_size"],
+        )
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
             **parallel_kwargs["dbo_decode_token_threshold"],
@@ -1557,6 +1562,7 @@ class EngineArgs:
             enable_expert_parallel=self.enable_expert_parallel,
             all2all_backend=self.all2all_backend,
             enable_dbo=self.enable_dbo,
+            ubatch_size=self.ubatch_size,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
             disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
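A hypothetical launch sketch showing how the new knob is passed through, assuming EngineArgs lives at its usual vllm.engine.arg_utils location; the model name and flag values are placeholders, not recommended settings.

# Command-line form (flags taken from the hunks above):
#   vllm serve <model> \
#       --enable-expert-parallel \
#       --all2all-backend deepep_low_latency \
#       --ubatch-size 4        # N-way overlap instead of the 2-way --enable-dbo
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", ubatch_size=4)
# ubatch_size is forwarded to ParallelConfig.ubatch_size when the engine config is built.
assert args.ubatch_size == 4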
@@ -201,10 +201,11 @@ def _make_metadata_with_slice(
     )
     # NOTE: last token can be outside of the last request if we have CG padding.

-    # If the "middle" request has tokens in both ubatches, we have to split it.
-    # If ubatch_slice is the first ubatch then we will be splitting the last
-    # request. If it's the second microbatch, then we will be splitting the
-    # first request
+    # If the request is split across ubatches, we have to adjust the metadata.
+    # splits_first_request: The first request in this slice is the continuation of
+    # a request that started in a previous slice.
+    # splits_last_request: The last request in this slice continues into the
+    # next slice.
     splits_first_request = first_tok > start_locs[first_req]
     splits_last_request = last_tok < start_locs[last_req + 1] - 1
@@ -225,7 +226,10 @@ def _make_metadata_with_slice(
         seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

     if splits_last_request:
-        tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
+        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
+        # the tokens skipped because query_start_loc_cpu might have been modified
+        # if splits_first_request is True.
+        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
         query_start_loc[-1] -= tokens_skipped
         query_start_loc_cpu[-1] -= tokens_skipped
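A worked toy example (not vLLM code) of the bookkeeping above, assuming start_locs is the original cumulative query_start_loc for three requests with 4, 6, and 2 tokens and the slice under construction covers tokens [7, 12):

import numpy as np

start_locs = np.array([0, 4, 10, 12])  # cumulative token offsets per request
first_req, last_req = 1, 2             # requests overlapping the token range [7, 12)
first_tok, last_tok = 7, 11

splits_first_request = first_tok > start_locs[first_req]       # 7 > 4   -> True
splits_last_request = last_tok < start_locs[last_req + 1] - 1  # 11 < 11 -> False

# Had the slice ended one token earlier (token_slice.stop == 11), the last request
# would be split and the skipped tokens are trimmed using the original start_locs:
token_slice_stop = 11
tokens_skipped = start_locs[last_req + 1] - token_slice_stop   # 12 - 11 = 1
assert bool(splits_first_request) and not bool(splits_last_request) and tokens_skipped == 1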
@@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
     check_ubatch_thresholds,
-    is_second_ubatch_empty,
+    is_last_ubatch_empty,
 )

 logger = init_logger(__name__)
@@ -56,7 +56,7 @@ def _run_ar(
     return tensor


-def _post_process_ubatch(tensor: torch.Tensor) -> bool:
+def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
     orig_num_tokens_tensor = tensor[0, :]
     padded_num_tokens_tensor = tensor[1, :]

@@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
     # there are no "empty" second ubatches
     orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
     padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
-    if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+    if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
         logger.debug(
             "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
         )
@@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
     assert should_attempt_dp_padding == should_dp_pad

     # Check conditions for microbatching
-    should_ubatch = _post_process_ubatch(tensor)
+    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

     if should_ubatch and not should_dp_pad:
         logger.debug_once(
@@ -2987,7 +2987,7 @@ class GPUModelRunner(

         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
-        if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
+        if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
             # Pre-compute cascade attention prefix lengths
             cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                 num_scheduled_tokens_np,
@@ -3028,6 +3028,13 @@ class GPUModelRunner(
                 num_scheduled_tokens_np,
                 num_tokens_padded,
                 num_reqs_padded,
+                self.parallel_config.num_ubatches,
             )

+            logger.debug(
+                "ubatch_slices: %s, ubatch_slices_padded: %s",
+                ubatch_slices,
+                ubatch_slices_padded,
+            )
+
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL
@@ -3710,11 +3717,14 @@ class GPUModelRunner(
         # wrap the model with full cudagraph wrapper if needed.
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
-        if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
+        if (
+            cudagraph_mode.has_full_cudagraphs()
+            and not self.parallel_config.use_ubatching
+        ):
             self.model = CUDAGraphWrapper(
                 self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
             )
-        elif self.parallel_config.enable_dbo:
+        elif self.parallel_config.use_ubatching:
             if cudagraph_mode.has_full_cudagraphs():
                 self.model = UBatchWrapper(
                     self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
@@ -4095,7 +4105,16 @@ class GPUModelRunner(
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
         ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
-            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+            should_ubatch,
+            num_scheduled_tokens,
+            num_tokens_padded,
+            num_reqs_padded,
+            self.vllm_config.parallel_config.num_ubatches,
         )
+        logger.debug(
+            "ubatch_slices: %s, ubatch_slices_padded: %s",
+            ubatch_slices,
+            ubatch_slices_padded,
+        )

         attn_metadata: PerLayerAttnMetadata | None = None
@@ -4644,7 +4663,7 @@ class GPUModelRunner(
         # is above the threshold. Otherwise we just capture a non-ubatched
         # version of the graph
         allow_microbatching = (
-            self.parallel_config.enable_dbo
+            self.parallel_config.use_ubatching
             and cudagraph_runtime_mode == CUDAGraphMode.FULL
             and uniform_decode
             and check_ubatch_thresholds(
@@ -4779,8 +4798,8 @@ class GPUModelRunner(
                 if kv_cache_group_id < len(kernel_block_sizes)
                 else None,
                 num_metadata_builders=1
-                if not self.parallel_config.enable_dbo
-                else 2,
+                if not self.parallel_config.use_ubatching
+                else self.parallel_config.num_ubatches,
             )
             # Calculate reorder batch threshold (if needed)
             # Note (tdoublep): do this *after* constructing builders,
@@ -103,8 +103,10 @@ class UBatchWrapper:
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
         self.comm_stream = torch.cuda.Stream(device=device)
-        # Two ubatch threads plus the main thread
-        self.ready_barrier = threading.Barrier(3)
+        # Ubatch threads plus the main thread
+        self.ready_barrier = threading.Barrier(
+            self.vllm_config.parallel_config.num_ubatches + 1
+        )

         self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
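A minimal sketch of the barrier sizing above, assuming one worker thread per ubatch plus the coordinating main thread; the names here are illustrative, not the UBatchWrapper internals.

import threading

num_ubatches = 4
ready_barrier = threading.Barrier(num_ubatches + 1)  # N ubatch threads + the main thread


def ubatch_worker(idx: int) -> None:
    # ... per-ubatch setup would happen here ...
    ready_barrier.wait()  # report ready, then all ubatches proceed together


threads = [threading.Thread(target=ubatch_worker, args=(i,)) for i in range(num_ubatches)]
for t in threads:
    t.start()
ready_barrier.wait()  # main thread joins the rendezvous and releases the workers
for t in threads:
    t.join()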
@@ -309,7 +311,7 @@ class UBatchWrapper:
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
                     self.vllm_config,
-                    dp_metadata=dp_metadata,
+                    dp_metadata=dp_metadata[i],
                     batch_descriptor=batch_descriptor,
                     cudagraph_runtime_mode=cudagraph_runtime_mode,
                 )
@@ -417,18 +419,19 @@ class UBatchWrapper:

         # We shouldn't be here unless we are running with multiple DP ranks
         assert dp_metadata is not None
-        num_tokens_per_ubatch = (
-            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
-        )
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        ubatch_num_tokens_across_dp = torch.tensor(
-            [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
-        )
-        ubatch_dp_metadata = DPMetadata.make(
-            self.vllm_config.parallel_config,
-            num_tokens_per_ubatch,
-            ubatch_num_tokens_across_dp,
-        )
+        ubatch_dp_metadata = []
+        for ubatch_slice in ubatch_slices:
+            dp_size = self.vllm_config.parallel_config.data_parallel_size
+            ubatch_num_tokens_across_dp = torch.tensor(
+                [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
+            )
+            ubatch_dp_metadata.append(
+                DPMetadata.make(
+                    self.vllm_config.parallel_config,
+                    ubatch_slice.num_tokens,
+                    ubatch_num_tokens_across_dp,
+                )
+            )

         if (
             num_tokens not in self.cudagraphs
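A simplified toy sketch (not vLLM's DPMetadata) of the change above: instead of deriving one metadata object from ubatch 0 and reusing it, the loop builds one entry per slice so unevenly sized ubatches each advertise their own token count.

from dataclasses import dataclass


@dataclass
class _Slice:           # stand-in for UBatchSlice.num_tokens
    num_tokens: int


@dataclass
class _ToyDPMetadata:   # stand-in for the fields DPMetadata.make would need
    num_tokens: int
    num_tokens_across_dp: list


def make_per_ubatch_metadata(slices, dp_size):
    return [_ToyDPMetadata(s.num_tokens, [s.num_tokens] * dp_size) for s in slices]


# Three uneven ubatches on a 2-rank DP group:
meta = make_per_ubatch_metadata([_Slice(128), _Slice(128), _Slice(96)], dp_size=2)
assert [m.num_tokens for m in meta] == [128, 128, 96]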
@@ -464,7 +467,7 @@ class UBatchWrapper:
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 compute_stream=compute_stream,
-                dp_metadata=dp_metadata,
+                dp_metadata=ubatch_dp_metadata,
                 batch_descriptor=batch_descriptor,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
             )
@@ -27,14 +27,16 @@ class UBatchSlice:
 UBatchSlices: TypeAlias = list[UBatchSlice]


-def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
-    return (padded_num_tokens // 2) >= orig_num_tokens
+def is_last_ubatch_empty(
+    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
+) -> bool:
+    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens


 def check_ubatch_thresholds(
     config: ParallelConfig, num_tokens: int, uniform_decode: bool
 ) -> bool:
-    if not config.enable_dbo:
+    if not config.use_ubatching:
         return False
     if uniform_decode:
         return num_tokens >= config.dbo_decode_token_threshold
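Worked arithmetic for the new emptiness check, assuming the helper keeps the signature shown above: padding splits padded_num_tokens evenly across N ubatches, and if the first N-1 ubatches alone already cover every real token, the last ubatch would be pure padding and ubatching is aborted.

from vllm.v1.worker.ubatch_utils import is_last_ubatch_empty

# 100 real tokens padded to 128, split 4 ways -> 32 tokens per ubatch.
# The first 3 ubatches hold 96 < 100 real tokens, so the last ubatch has real work.
assert not is_last_ubatch_empty(100, 128, num_ubatches=4)

# 90 real tokens padded to 128: the first 3 ubatches (96 slots) already cover all
# real tokens, so the 4th ubatch would be empty.
assert is_last_ubatch_empty(90, 128, num_ubatches=4)

# With num_ubatches=2 this reduces to the old is_second_ubatch_empty behaviour.
assert is_last_ubatch_empty(60, 128, num_ubatches=2)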
@@ -42,21 +44,17 @@ def check_ubatch_thresholds(
         return num_tokens >= config.dbo_prefill_token_threshold


-# This just pads the second ubatch slice out to the total number of tokens
+# This pads the last ubatch slice out to the total number of tokens
 # (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
 def _pad_out_ubatch_slices(
     ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
-    # TODO(lucas): handle empty second ubatch
-    padded_second_request_slice = slice(
-        ubatch_slices[1].request_slice.start, num_reqs_padded
-    )
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    return [
-        ubatch_slices[0],
-        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    last_slice = ubatch_slices[-1]
+    padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
+    padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
+
+    return ubatch_slices[:-1] + [
+        UBatchSlice(padded_last_request_slice, padded_last_token_slice)
     ]
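A usage sketch for the padding helper above, assuming UBatchSlice keeps its (request_slice, token_slice) layout; the numbers are made up. Two ubatches over 100 real tokens and 8 requests are padded out to 128 tokens and 10 requests, and only the last slice grows.

from vllm.v1.worker.ubatch_utils import UBatchSlice, _pad_out_ubatch_slices

slices = [
    UBatchSlice(slice(0, 4), slice(0, 50)),
    UBatchSlice(slice(4, 8), slice(50, 100)),
]
padded = _pad_out_ubatch_slices(slices, num_total_tokens=128, num_reqs_padded=10)
assert padded[0] is slices[0]                    # earlier slices are untouched
assert padded[-1].token_slice == slice(50, 128)  # last slice absorbs the token padding
assert padded[-1].request_slice == slice(4, 10)  # and the request padding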
@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
     num_scheduled_tokens: np.ndarray,
     num_tokens_padded: int,
     num_reqs_padded: int,
-    split_point: int | None = None,
+    num_ubatches: int,
+    split_point: list[int] | int | None = None,
 ) -> tuple[UBatchSlices | None, UBatchSlices | None]:
     if not should_ubatch:
         return None, None

     if split_point is None:
-        split_point = int(num_tokens_padded) // 2
+        split_point = int(num_tokens_padded) // num_ubatches
+
+    token_split_points = [split_point * i for i in range(1, num_ubatches)]

     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
     np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

-    first_ubatch_token_slice = slice(0, split_point)
-    second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
-
-    # Determine request slices using exclusive stop semantics
-    # First ubatch includes requests whose tokens overlap [0, split_point)
-    first_ubatch_req_stop = int(
-        np.searchsorted(cu_num_tokens, split_point, side="left")
-    )
-    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
-
-    # Second ubatch starts at the request that contains the split_point
-    # or the request starting exactly at split_point (if on boundary)
-    second_ubatch_req_start = int(
-        np.searchsorted(cu_num_tokens, split_point, side="right") - 1
-    )
-    second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
-
-    ubatch_slices = [
-        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
-        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
-    ]
+    ubatch_slices = []
+    start_token = 0
+
+    # Add the end point to the split points to make iteration easier
+    all_points = token_split_points + [cu_num_tokens[-1]]
+
+    for end_token in all_points:
+        token_slice = slice(start_token, end_token)
+
+        # Determine request slices using exclusive stop semantics
+        # Ubatch includes requests whose tokens overlap [start_token, end_token)
+
+        # Start at the request that contains the start_token
+        # or the request starting exactly at start_token (if on boundary)
+        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
+
+        # Stop at the request that starts at or after end_token
+        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
+
+        req_slice = slice(req_start, req_stop)
+        ubatch_slices.append(UBatchSlice(req_slice, token_slice))
+
+        start_token = end_token

     ubatch_slices_padded = _pad_out_ubatch_slices(
         ubatch_slices, num_tokens_padded, num_reqs_padded
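A standalone toy version (not vLLM code) of the N-way split above: tokens are cut at evenly spaced split points, and for each ubatch np.searchsorted finds the overlapping request range, so a request that straddles a boundary shows up in both neighbouring ubatches.

import numpy as np


def split_into_ubatches(num_scheduled_tokens: np.ndarray, num_ubatches: int):
    cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu[1:])
    total = int(cu[-1])
    split = total // num_ubatches
    end_points = [split * i for i in range(1, num_ubatches)] + [total]

    slices, start = [], 0
    for end in end_points:
        req_start = int(np.searchsorted(cu, start, side="right") - 1)
        req_stop = int(np.searchsorted(cu, end, side="left"))
        slices.append((slice(req_start, req_stop), slice(start, end)))
        start = end
    return slices


# Three requests with 4, 6, and 2 tokens split into 3 ubatches of 4 tokens each:
slices = split_into_ubatches(np.array([4, 6, 2]), num_ubatches=3)
assert slices == [
    (slice(0, 1), slice(0, 4)),   # ubatch 0: request 0 only
    (slice(1, 2), slice(4, 8)),   # ubatch 1: part of request 1
    (slice(1, 3), slice(8, 12)),  # ubatch 2: rest of request 1 plus request 2
]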
@@ -7,10 +7,15 @@ import torch

 from vllm import forward_context
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import current_stream

 logger = init_logger(__name__)

 _THREAD_ID_TO_CONTEXT: dict = {}
-_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
+# Default the number of microbatches to 2 (DBO); make_ubatch_contexts updates it.
+_NUM_UBATCHES: int = 2
+_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = []


 class UBatchContext:
@@ -48,6 +53,7 @@ class UBatchContext:
         global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
         _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
         _CURRENT_CONTEXTS[self.id] = self
+        # _NUM_UBATCHES is set in make_ubatch_contexts
         self.ready_barrier.wait()

         self.cpu_wait_event.wait()
@@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function(
 def dbo_register_recv_hook(recv_hook):
     if len(_THREAD_ID_TO_CONTEXT) > 0:
         ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
-        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
+        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
         next_ctx.recv_hook = recv_hook

@@ -202,7 +208,14 @@ def make_ubatch_contexts(
     ready_barrier: threading.Barrier,
     schedule: str = "default",
 ) -> list[UBatchContext]:
-    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
+    global _NUM_UBATCHES, _CURRENT_CONTEXTS
+    assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
+
+    _NUM_UBATCHES = num_micro_batches
+    # Ensure the global context list is large enough
+    if len(_CURRENT_CONTEXTS) < num_micro_batches:
+        _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
+
     """
     Create a context manager for micro-batching synchronization.
     """
@@ -210,8 +223,6 @@ def make_ubatch_contexts(
     gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
     gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]

-    assert len(forward_contexts) == 2
-
     ctxs = []
     for i in range(num_micro_batches):
         ctx = UBatchContext(