[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to vLLM (#23693)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Sage Moore authored on 2025-09-16 09:21:48 -07:00, committed by GitHub
parent 08369289af
commit 567939953b
GPG Key ID: B5690EEEBB952194
22 changed files with 1257 additions and 172 deletions

View File

@ -87,6 +87,11 @@ def parse_args():
default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
)
parser.add_argument(
"--enable-dbo",
action="store_true",
help=("Enable microbatched execution"),
)
parser.add_argument(
"--compilation-config",
type=int,
@ -113,6 +118,7 @@ def main(
max_model_len,
compilation_config,
gpu_memory_utilization,
enable_dbo,
quantization,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@ -167,6 +173,7 @@ def main(
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enable_dbo=enable_dbo,
quantization=quantization,
compilation_config=compilation_config,
)
@ -227,6 +234,7 @@ if __name__ == "__main__":
args.max_model_len,
args.compilation_config,
args.gpu_memory_utilization,
args.enable_dbo,
args.quantization,
),
)
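For reference, a minimal sketch (not part of this diff) of driving the new flag through the offline API; the model name is a placeholder, and the deepep_low_latency backend plus the DeepEP kernels are assumed to be installed, per the config check added further below.
import os
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"
from vllm import LLM
llm = LLM(
    model="example/moe-model",        # placeholder model name
    enable_expert_parallel=True,      # DBO overlaps MoE all2all with compute
    enable_dbo=True,                  # new flag added in this PR
    dbo_decode_token_threshold=32,    # new knob; minimum decode tokens to split
)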

View File

@ -6,7 +6,7 @@ import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UbatchSlice,
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
@ -106,7 +106,7 @@ def mixed_small_metadata():
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
"""Test slicing decode batch metadata"""
# Split first request only
ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))
result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
"""Test slicing mixed batch metadata"""
ubatch_slice = UbatchSlice(slice(1, 3),
ubatch_slice = UBatchSlice(slice(1, 3),
slice(1, 7)) # Requests 1-3, tokens 1-7
result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
num_tokens = large_decode_metadata.num_reqs
mid_point = num_tokens // 2
ubatch_slices = [
UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
num_tokens)),
]

View File

@ -365,7 +365,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@ -489,7 +491,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,

View File

@ -2848,6 +2848,14 @@ class VllmConfig:
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \
"Microbatching currently only supports the deepep_low_latency "\
f"all2all backend. {a2a_backend} is not supported. To fix set "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\
"deepep_low_latency and install the DeepEP kerenls."
if not self.instance_id:
self.instance_id = random_uuid()[:5]

View File

@ -137,6 +137,14 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
enable_dbo: bool = False
"""Enable microbatching for the model executor."""
dbo_decode_token_threshold: int = 32
"""The threshold for microbatching. If the number of tokens in the
request is greater than this threshold, microbatching will be used.
Otherwise, the request will be processed in a single batch."""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

View File

@ -251,9 +251,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle

View File

@ -327,6 +327,9 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = \
ParallelConfig.dbo_decode_token_threshold
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \
@ -695,6 +698,11 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument("--enable-dbo",
**parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config",
@ -1339,6 +1347,8 @@ class EngineArgs:
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy,

View File

@ -14,6 +14,7 @@ import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@ -97,6 +98,53 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()
@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int, dp_size: int,
dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]
"""
device = current_platform.device_type
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(tensor, group=get_dp_group().device_group)
result: bool = bool(torch.all(tensor[2] == 1).item())
if not result:
return result, None
orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
padded_max_num_tokens)
return False, None
return result, padded_num_tokens_tensor.cpu()
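A minimal single-process sketch (illustrative only, not part of this diff) of the consensus pattern described in the docstring above: every rank contributes (orig_tokens, padded_tokens, should_ubatch), and microbatching proceeds only when all ranks agree; the is_second_ubatch_empty abort check is elided here.
import torch
def should_ubatch_consensus(rows):
    # rows[i] = (orig_num_tokens, padded_num_tokens, should_ubatch) for DP
    # rank i, emulating the all-reduced 3 x dp_size tensor on one process.
    t = torch.tensor([[o for o, _, _ in rows],
                      [p for _, p, _ in rows],
                      [int(s) for _, _, s in rows]], dtype=torch.int32)
    if not bool(torch.all(t[2] == 1).item()):
        return False, None  # any rank opting out aborts microbatching
    return True, t[1]       # per-rank padded token counts
print(should_ubatch_consensus([(64, 64, True), (48, 64, True)]))
# -> (True, tensor([64, 64], dtype=torch.int32))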
@staticmethod
def make(
parallel_config: ParallelConfig,
@ -119,14 +167,15 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None
or num_tokens_across_dp[dp_rank] == batchsize)
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
if num_tokens_across_dp is None:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)
@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
@ -179,9 +228,12 @@ class ForwardContext:
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
set dynamically for each forward pass
Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch.
Set dynamically for each forward pass
"""
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
list[dict[str, "AttentionMetadata"]]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
@ -191,6 +243,8 @@ class ForwardContext:
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None
ubatch_slices: Optional[UBatchSlices] = None
def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
return _forward_context
def create_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
dp_metadata: Optional[DPMetadata] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
ubatch_slices: Optional[UBatchSlices] = None):
return ForwardContext(no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices)
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
"""A context manager that overrides the current forward context.
This is used to override the forward context for a specific
forward pass.
"""
global _forward_context
prev_context = _forward_context
_forward_context = forward_context
try:
yield
finally:
_forward_context = prev_context
@contextmanager
def set_forward_context(
attn_metadata: Any,
@ -216,7 +303,8 @@ def set_forward_context(
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
ubatch_slices: Optional[UBatchSlices] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@ -225,6 +313,7 @@ def set_forward_context(
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None):
@ -232,20 +321,14 @@ def set_forward_context(
attn_metadata, num_tokens or 0,
num_tokens_across_dp)
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
)
forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, dp_metadata,
cudagraph_runtime_mode,
batch_descriptor, ubatch_slices)
try:
yield
with override_forward_context(forward_context):
yield
finally:
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
@ -282,5 +365,3 @@ def set_forward_context(
logger.info(("Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
_forward_context = prev_context

View File

@ -191,7 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> Callable:
) -> tuple[Callable, mk.ReceiverType]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@ -217,13 +217,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q_scale = None
a1_post_scale = a1_scale
return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config))
def prepare(
self,
@ -237,10 +238,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids, num_experts,
expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize(

View File

@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
@ -55,7 +58,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
self.handles: list[Optional[tuple]] = [None, None]
self.num_dispatchers_ = num_dispatchers
def num_dispatchers(self) -> int:
@ -123,13 +126,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType:
) -> tuple[Callable, mk.ReceiverType]:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
f"{self.SUPPORTED_HIDDEN_SIZES}")
a2a_idx = dbo_current_ubatch_id()
if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
@ -148,7 +153,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
expert_x, expert_num_tokens, handle, _, hook = \
self.buffer.low_latency_dispatch(a1,
topk_ids,
self.max_tokens_per_rank,
@ -156,21 +161,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=True)
self.handles[a2a_idx] = handle
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config))
def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
@ -192,10 +195,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
hook()
return receiver()
def finalize(
@ -210,7 +215,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
assert self.handle is not None
a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled()
handle = self.handles[a2a_idx]
assert handle is not None
combine_topk_weights = topk_weights
if apply_router_weight_on_input:
@ -218,12 +227,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
handle,
async_finish=False,
zero_copy=False,
return_recv_hook=False,
return_recv_hook=do_recv_hook,
out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook)
dbo_yield()

View File

@ -38,6 +38,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@ -992,16 +993,28 @@ class FusedMoE(CustomOp):
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
if vllm_config.parallel_config.enable_dbo:
self.batched_hidden_states = torch.zeros(
(2, moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(2, moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
else:
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
@ -1708,14 +1721,29 @@ class FusedMoE(CustomOp):
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
assert (self.batched_hidden_states.size(0) # type: ignore
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[
batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[
batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits
assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
assert (batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states = batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)

View File

@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
#
# This file defines a set of base classes used to make MoE kernels more modular.
@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
) -> tuple[Callable, ReceiverType]:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
return None
class SharedResizableBuffer:
def __init__(self):
self.buffer = None
def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype):
shape_numel = prod(shape)
if self.buffer is None or self.buffer.numel() < shape_numel:
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
assert self.buffer.device == device, \
f"Buffer device mismatch: {self.buffer.device} != {device}"
assert self.buffer.dtype == dtype, \
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
return self.buffer[:shape_numel].view(*shape)
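A quick illustration (not part of the diff) of the reuse pattern above, assuming a CUDA device is available; buf, w1, and w2 are illustrative names. Repeated get() calls hand back views over a single allocation, growing it only when a larger workspace is requested.
buf = SharedResizableBuffer()
w1 = buf.get((8, 128), device=torch.device("cuda"), dtype=torch.bfloat16)
w2 = buf.get((4, 128), device=torch.device("cuda"), dtype=torch.bfloat16)
# w2 is a view over the same storage as w1; no new allocation was made.
assert w1.untyped_storage().data_ptr() == w2.untyped_storage().data_ptr()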
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
layer due to any layer specific state that may be used by the component
objects.
"""
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer()
workspace2_buffer = SharedResizableBuffer()
def __init__(
self,
@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = torch.empty(prod(workspace13_shape),
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device=a1.device,
dtype=workspace_dtype)
workspace13 = self.workspace13_buffer.get(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = self.workspace2_buffer.get(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
assert fused_out is None or fused_out.shape == fused_out_shape, (
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
fused_out = self.fused_out_buffer.get(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
def slice_input_tensors(
chunk_idx: int
@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):
shared_output: torch.Tensor
if (not self.prepare_finalize.supports_async()
or self.shared_experts is None):
if not self.prepare_finalize.supports_async():
# When DBO is enabled we shouldn't be running an a2a kernel
# that doesn't support async prepare/finalize
assert not dbo_enabled()
# Run shared experts serially with dispatch.
if self.shared_experts is not None:
@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
)
else:
# Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async(
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config,
)
assert self.shared_experts is not None
shared_output = self.shared_experts(a1)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
if not dbo_enabled():
hook()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()
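The change above can be summarized with this minimal sketch (hypothetical helper, not from the diff): prepare_async() now returns a (hook, receiver) pair, and the synchronous prepare() path is simply "run the hook, then the receiver", while the DBO path hands the hook to the ubatch scheduler so the other microbatch's compute overlaps the dispatch.
def synchronous_prepare(prepare_async, *args):
    # Equivalent to the non-DBO fallback: wait for the all2all dispatch to
    # land, then unpack/quantize the received tokens.
    hook, receiver = prepare_async(*args)
    hook()
    return receiver()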

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Callable, Optional, Union
import pplx_kernels as pplx
import torch
@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType:
) -> tuple[Callable, mk.ReceiverType]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
@ -214,30 +214,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=False,
)
return lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
a1q,
a1q_scale,
topk_ids,
bound_m,
orig_a_scale_block_shape,
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor,
bound_m: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType:
self.a2a.dispatch(
hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
@ -249,6 +226,21 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=True,
)
return (hook, lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
orig_a_scale_block_shape,
))
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
@ -270,7 +262,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
hook, receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
@ -281,6 +273,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input,
quant_config,
)
hook()
return receiver()
def finalize(

View File

@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"]
@ -81,12 +82,6 @@ class CommonAttentionMetadata:
encoder_seq_lens: Optional[np.ndarray] = None
@dataclass
class UbatchSlice:
request_slice: slice
token_slice: slice
def slice_query_start_locs(
query_start_loc: torch.Tensor,
request_slice: slice,
@ -103,7 +98,7 @@ def slice_query_start_locs(
def _make_metadata_with_slice(
ubatch_slice: UbatchSlice,
ubatch_slice: UBatchSlice,
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
"""
This function creates a new CommonAttentionMetadata that corresponds to
@ -133,6 +128,11 @@ def _make_metadata_with_slice(
torch.max(torch.abs(query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])).item())
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if max_query_len == 0:
max_query_len = attn_metadata.max_query_len
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice]
@ -152,12 +152,12 @@ def _make_metadata_with_slice(
def split_attn_metadata(
ubatch_slices: list[UbatchSlice],
ubatch_slices: list[UBatchSlice],
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UbatchSlice in ubatch_slices.
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""

View File

@ -27,6 +27,7 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
@ -179,9 +180,11 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@ -355,8 +358,9 @@ class EagleProposer:
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
ubatch_id = dbo_current_ubatch_id()
tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builder
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)

View File

@ -64,8 +64,13 @@ class CPUModelRunner(GPUModelRunner):
if not self.attn_groups[0]:
return
mb = getattr(self.attn_groups[0][0], "metadata_builder", None)
if not isinstance(mb, TorchSDPAMetadataBuilderV1):
mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
if isinstance(mb, list):
if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
return
mb[0].reorder_batch(self.input_batch, scheduler_output)
return
elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
# Encoder-only / rerank models do not benefit from reordering,
# so we safely skip here.
return

View File

@ -15,6 +15,7 @@ import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeAlias
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
@ -55,11 +56,12 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills)
reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@ -85,9 +87,12 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from .utils import (AttentionGroup, MultiModalBudget,
@ -105,6 +110,11 @@ else:
logger = init_logger(__name__)
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
AttnMetadataDict]
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
@ -274,6 +284,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
@ -872,10 +883,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return encoder_seq_lens
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
self, scheduler_output: "SchedulerOutput"
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
Optional[SpecDecodeMetadata], np.ndarray,
Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
Optional[torch.Tensor]]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
@ -947,6 +959,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded)
ubatch_slices, num_tokens_after_padding = \
ubatch_split(max_num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
self.vllm_config)
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
@ -1001,7 +1022,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices)
attn_metadata: dict[str, Any] = {}
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@ -1075,7 +1098,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = attn_group.metadata_builder
builder = attn_group.get_metadata_builder()
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
@ -1093,13 +1116,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (attn_group.get_metadata_builder(
ubatch_id=ubid).build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
assert isinstance(attn_metadata, dict)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# Hot-Swap lora model
if self.lora_config:
@ -1107,7 +1144,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
max_num_scheduled_tokens, ubatch_slices,
num_tokens_after_padding)
def _compute_cascade_attn_prefix_len(
self,
@ -1508,7 +1546,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper):
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
return self.model.unwrap()
return self.model
@ -1675,6 +1713,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Determines the total number of tokens that each rank will run.
All ranks will be padded out so that they run with the same number
of tokens
Returns: tuple[
num_pad_tokens: The number of tokens that will be added to the batch
num_tokens_after_padding: A tensor containing the total number of
tokens for each DP rank including padding.
]
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
@ -1698,6 +1747,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def get_local_padding(self, num_tokens_unpadded: int) -> int:
num_tokens_padded = num_tokens_unpadded
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(
num_tokens_unpadded)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
return num_pad_tokens
# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
num_total_tokens: int):
padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
num_total_tokens)
ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice,
padded_second_ubatch_slice)
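An illustrative example (not in the diff) of the padding above, where runner is a hypothetical GPUModelRunner with 6 real tokens split 3/3 and a padded total of 8:
slices = [UBatchSlice(slice(0, 3), slice(0, 3)),
          UBatchSlice(slice(3, 6), slice(3, 6))]
runner.pad_out_ubatch_slice(slices, num_total_tokens=8)
# slices[1] now spans [3, 8) for both requests and tokens: the two padded
# tokens land entirely in the second microbatch.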
def _pool(
self,
hidden_states: torch.Tensor,
@ -1758,15 +1840,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
ubatch_slices: Optional[UBatchSlices] = None,
num_tokens_after_padding: Optional[torch.Tensor] = None,
) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor,
Optional[IntermediateTensors], dict[str, Any]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
if ubatch_slices:
assert num_tokens_after_padding is not None
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif ubatch_slices is None:
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
num_pad, num_tokens_after_padding = self.get_dp_padding(
num_input_tokens)
num_input_tokens += num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
@ -1821,7 +1910,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return (
num_scheduled_tokens,
num_input_tokens,
num_tokens_across_dp,
num_tokens_after_padding,
input_ids,
inputs_embeds,
positions,
@ -2027,7 +2116,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
max_query_len, ubatch_slices, num_tokens_after_padding
) = self._prepare_inputs(scheduler_output)
finally:
if self.prepare_inputs_event is not None:
@ -2042,7 +2132,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions,
intermediate_tensors,
model_kwargs,
) = self._preprocess(scheduler_output, intermediate_tensors)
) = self._preprocess(scheduler_output, intermediate_tensors,
ubatch_slices, num_tokens_after_padding)
if ubatch_slices is not None:
num_input_tokens = num_input_tokens // 2
uniform_decode = (max_query_len
== self.uniform_decode_query_len) and (
@ -2062,6 +2156,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
), record_function_or_nullcontext("Forward"),
self.maybe_get_kv_connector_output(scheduler_output) as
kv_connector_output):
@ -2441,10 +2536,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# CudagraphWraper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \
and not self.parallel_config.enable_dbo:
self.model = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
elif self.parallel_config.enable_dbo:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper(self.model, self.vllm_config,
CUDAGraphMode.FULL, self.device)
else:
self.model = UBatchWrapper(self.model, self.vllm_config,
CUDAGraphMode.NONE, self.device)
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
@ -2642,6 +2745,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
allow_microbatching: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
create_mixed_batch: bool = False,
@ -2667,12 +2771,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
"""
ubatch_enabled = self.parallel_config.enable_dbo
num_tokens_across_dp = None
num_pad = 0
should_ubatch = False
if ubatch_enabled:
should_ubatch = num_tokens >= \
self.parallel_config.dbo_decode_token_threshold and \
allow_microbatching
(should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
num_tokens, num_tokens, should_ubatch, self.vllm_config)
# Currently the dummy run should only use microbatching during
# cuda graph capture, meaning all DP ranks should already
# have the same batch size
if num_tokens_across_dp is not None:
assert int(num_tokens_across_dp[0]) == num_tokens // 2
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
if not should_ubatch:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and
@ -2690,6 +2812,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
if allow_microbatching:
assert self.uniform_decode_query_len == 1
assert uniform_decode is True
assert max_query_len == 1
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@ -2728,12 +2854,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
attn_metadata: Optional[dict[str, Any]] = None
ubatch_slices = None
# We currently only microbatch if the number of tokens is
# over a certain threshold.
if should_ubatch:
# We only support decode-only cudagraphs
assert num_reqs == num_tokens
assert num_tokens % 2 == 0
ubatch_slices = [
UBatchSlice(slice(0, num_reqs // 2), slice(0,
num_tokens // 2)),
UBatchSlice(slice(num_reqs // 2, num_reqs),
slice(num_tokens // 2, num_tokens))
]
attn_metadata: Optional[PerLayerAttnMetadata] = None
# If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
if create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use
@ -2766,12 +2908,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\
.build_for_cudagraph_capture(common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (attn_group\
.get_metadata_builder(ubatch_id=ubid)\
.build_for_cudagraph_capture(common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][
layer_name] = attn_metadata_i
else:
assert type(attn_metadata) is dict
attn_metadata_i = attn_group.get_metadata_builder()\
.build_for_cudagraph_capture(common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens, remove_lora):
@ -2818,13 +2974,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
if ubatch_slices is not None:
num_tokens = num_tokens // 2
with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor):
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices):
outputs = self.model(
input_ids=input_ids,
positions=positions,
@ -3096,6 +3255,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
@ -3153,6 +3313,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
cudagraph_runtime_mode.name))
enable_dbo = self.parallel_config.enable_dbo
# DBO only supports running full cudagraphs with uniform
# decode lengths
if enable_dbo and uniform_decode:
for num_tokens in compilation_cases:
# If the number of tokens is below the microbatching
# threshold, don't capture a microbatched cudagraph
if (num_tokens
< self.parallel_config.dbo_decode_token_threshold):
continue
# Warmup
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=True,
allow_microbatching=True,
skip_eplb=True)
# Graph Capture
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True,
allow_microbatching=True,
skip_eplb=True)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
@ -3219,14 +3408,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for attn_backend, layer_names in attn_backends_map.items():
attn_metadata_builder_i = attn_backend.get_builder_cls()(
attn_metadata_builders = []
attn_metadata_builders.append(attn_backend.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
))
if self.parallel_config.enable_dbo:
attn_metadata_builders.append(
attn_backend.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
))
attn_group = AttentionGroup(attn_backend,
attn_metadata_builder_i,
attn_metadata_builders,
layer_names)
attn_groups.append(attn_group)
return attn_groups
@ -3246,11 +3444,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
min_cg_builder_name = None
for attn_group in self._attn_group_iterator():
builder = attn_group.metadata_builder
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder.cudagraph_support
min_cg_builder_name = builder.__class__.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check cudagraph for mixed batch is supported
@ -3316,7 +3513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is compatible (e.g., decode threshold is the same)
"""
for group in self._attn_group_iterator():
attn_metadata_builder_i = group.metadata_builder
attn_metadata_builder_i = group.get_metadata_builder()
# check that if any backends reorder batches; that the reordering
# is compatible (e.g., decode threshold is the same)

View File

@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import threading
from typing import Any, Callable, Optional
import torch
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@dataclasses.dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: Optional[torch.Tensor]
intermediate_tensors: Optional[IntermediateTensors]
num_tokens: int
@dataclasses.dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
ubatch_metadata: UbatchMetadata
outputs: Optional[Any] = None
class UBatchWrapper:
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, device: torch.cuda.device):
self.runnable = runnable
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device)
# Two ubatch threads plus the main thread
self.ready_barrier = threading.Barrier(3)
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
self.cudagraph_wrapper = None
self.graph_pool = None
if runtime_mode is not CUDAGraphMode.NONE:
self.cudagraph_wrapper = CUDAGraphWrapper(
runnable, vllm_config, runtime_mode=runtime_mode)
self.graph_pool = current_platform.get_global_graph_pool()
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} does not exist in the runnable of "
f"ubatch wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
"""
Capture a cudagraph for a microbatched run.
The logic here is somewhat complicated because we need to make sure that
each of the ubatch threads initializes its cuda context before we start
the graph capture.
The flow is as follows:
1. The main thread starts up each ubatch thread. Each thread will
initialize its cuda context (torch.cuda.current_blas_handle())
before going to sleep upon entering the ubatch_context.
2. The main thread starts the graph capture and wakes up the first
ubatch thread.
3. Each ubatch thread runs the model to completion and returns the
completed output tensors back to the main thread.
4. The main thread stores the captured cudagraph along with its metadata
and returns
"""
@torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata):
ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle()
with torch.cuda.stream(ubatch_context.comm_stream):
_ = torch.cuda.current_blas_handle()
with ubatch_context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
compute_stream = ubatch_metadata[0].context.compute_stream
num_tokens = ubatch_metadata[0].num_tokens + \
ubatch_metadata[1].num_tokens
# Ubatches manage the forward context manually, so we override it to
# None here so that it is restored correctly later
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(target=_capture_ubatch_thread,
args=(
results,
metadata,
))
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
# Capture the cudagraph
cudagraph_metadata = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
ubatch_metadata=ubatch_metadata,
)
with torch.cuda.graph(cudagraph_metadata.cudagraph,
stream=compute_stream,
pool=self.graph_pool):
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
cudagraph_metadata.outputs = result
self.cudagraphs[num_tokens] = cudagraph_metadata
return cudagraph_metadata.outputs
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
results: list[tuple[int, torch.Tensor]] = []
        # Ubatch threads manually manage the forward context, so we
        # override it to None here so that it can be restored correctly
        # after both threads have finished.
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
model,
metadata,
))
ubatch_threads.append(thread)
thread.start()
self.ready_barrier.wait() # Wait for both threads to be ready
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
positions, inputs_embeds, intermediate_tensors,
compute_stream, dp_metadata, batch_descriptor,
cudagraph_runtime_mode) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append(
create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode))
ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream,
forward_contexts=forward_contexts,
ready_barrier=self.ready_barrier)
ubatch_metadata: list[UbatchMetadata] = []
for i, ubatch_slice in enumerate(ubatch_slices):
sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
sliced_intermediate_tensors = \
self._slice_model_inputs(
ubatch_slice.token_slice, input_ids, positions,
inputs_embeds, intermediate_tensors)
ubatch_metadata.append(
UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=sliced_input_ids,
positions=sliced_positions,
inputs_embeds=sliced_inputs_embeds,
intermediate_tensors=sliced_intermediate_tensors,
num_tokens=ubatch_slice.token_slice.stop -
ubatch_slice.token_slice.start))
return ubatch_metadata
def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
inputs_embeds, intermediate_tensors):
sliced_input_ids = input_ids[tokens_slice]
        # If we are using M-RoPE, the positions tensor has an additional
        # leading dimension, so slice along the token (last) dimension.
if positions.ndim == 2:
sliced_positions = positions[:, tokens_slice]
else:
sliced_positions = positions[tokens_slice]
        sliced_inputs_embeds = inputs_embeds[
            tokens_slice] if inputs_embeds is not None else None
        sliced_intermediate_tensors = intermediate_tensors[
            tokens_slice] if intermediate_tensors is not None else None
return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
sliced_intermediate_tensors)
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
ubatch_slices = forward_context.ubatch_slices
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
CUDAGraphMode.PIECEWISE):
return self.runnable(*args, **kwargs)
else:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
num_tokens = (ubatch_slices[0].token_slice.stop -
ubatch_slices[0].token_slice.start) * 2
input_ids = kwargs['input_ids']
positions = kwargs['positions']
intermediate_tensors = kwargs['intermediate_tensors']
inputs_embeds = kwargs['inputs_embeds']
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
if num_tokens not in self.cudagraphs \
and cudagraph_runtime_mode is CUDAGraphMode.FULL:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE)
return self._capture_ubatches(ubatch_metadata, self.model)
elif num_tokens in self.cudagraphs:
cudagraph_metadata = self.cudagraphs[num_tokens]
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE)
return self._run_ubatches(ubatch_metadata, self.model)
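# A minimal, self-contained sketch of the CPU-side handoff that
# _capture_ubatches/_run_ubatches above rely on: two threads take turns via a
# pair of events so that only one runs at a time, and their partial results
# are reassembled in ubatch order (mirroring the torch.cat over sorted ids).
# The names below (_two_thread_ping_pong, halves) are illustrative only and
# are not part of vLLM.
def _two_thread_ping_pong(halves: list[list[int]]) -> list[int]:
    events = [threading.Event(), threading.Event()]
    results: list[tuple[int, list[int]]] = []

    def worker(idx: int, data: list[int]) -> None:
        # Wait for this thread's turn, do its share of the "work", then
        # hand control to the sibling thread.
        events[idx].wait()
        events[idx].clear()
        results.append((idx, [x * 2 for x in data]))
        events[(idx + 1) % 2].set()

    threads = [
        threading.Thread(target=worker, args=(i, half))
        for i, half in enumerate(halves)
    ]
    for t in threads:
        t.start()
    events[0].set()  # wake the first "ubatch", like cpu_wait_event.set()
    for t in threads:
        t.join()
    return [x for _, chunk in sorted(results) for x in chunk]

# e.g. _two_thread_ping_pong([[1, 2], [3, 4]]) -> [2, 4, 6, 8]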

View File

@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
is_second_ubatch_empty)
logger = init_logger(__name__)
def should_ubatch_with_num_tokens(
should_ubatch: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
return DPMetadata.should_ubatch_across_dp(should_ubatch,
orig_num_tokens_per_ubatch,
padded_num_tokens_per_ubatch,
dp_size, dp_rank)
def get_dp_padding_ubatch(
num_tokens_unpadded: int, num_tokens_padded: int,
should_attempt_ubatching: bool,
vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching. It will "abort" meaning that no padding
information will be returned to the caller. It will return (False, None)
2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that the run with the same number
of tokens
Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch if False
]
"""
assert num_tokens_padded >= num_tokens_unpadded
dp_size = vllm_config.parallel_config.data_parallel_size
if dp_size == 1:
# Early exit.
return False, None
    # If this DP rank doesn't want to attempt microbatching, it still
    # participates in the handshake so that every rank aborts together.
if not should_attempt_ubatching:
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
False, 0, 0, vllm_config)
assert should_ubatch is False
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp
# Round up to the next multiple of two for even divisibility
num_tokens_padded = round_up(num_tokens_padded, 2)
num_tokens_per_ubatch = num_tokens_padded // 2
should_ubatch = True
    # Sanity check that the existing padding isn't giving us an empty second
    # ubatch; abort if so.
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.debug(
"Empty second µbatch detected: unpadded tokens: %s, padded "
"tokens: %s", num_tokens_unpadded, num_tokens_padded)
should_ubatch = False
    # Note that we pass per-ubatch token counts (not totals) into the
    # cross-DP handshake.
(should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
vllm_config)
if not should_ubatch:
assert num_tokens_across_dp is None
return should_ubatch, num_tokens_across_dp
assert num_tokens_across_dp is not None
max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return should_ubatch, num_tokens_after_padding
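# A small illustrative sketch, with made-up numbers, of the padding step
# above: once the ranks have exchanged their per-ubatch token counts, every
# rank pads up to the maximum so that all ranks launch with identical shapes.
# `per_rank_counts` stands in for the tensor returned by the
# DPMetadata.should_ubatch_across_dp handshake; this helper is not called
# anywhere in vLLM.
def _pad_to_max_across_dp(per_rank_counts: list[int]) -> torch.Tensor:
    max_tokens = max(per_rank_counts)
    return torch.tensor([max_tokens] * len(per_rank_counts),
                        device="cpu",
                        dtype=torch.int32)

# e.g. ranks reporting [96, 128] tokens per ubatch all run 128:
# _pad_to_max_across_dp([96, 128]) -> tensor([128, 128], dtype=torch.int32)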
def ubatch_split(
max_num_scheduled_tokens: int,
num_tokens_unpadded: int,
num_tokens_padded: int,
vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
should be split into microbatches.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if ubatch_slices is None
]
"""
parallel_config = vllm_config.parallel_config
# Don't bother with the should_ubatch handshaking unless microbatching
# is enabled
if not parallel_config.enable_dbo:
return (None, None)
# Check preconditions for microbatching
should_attempt_ubatching = \
parallel_config.enable_dbo and \
num_tokens_unpadded >= \
parallel_config.dbo_decode_token_threshold \
and max_num_scheduled_tokens == 1
# Don't microbatch unless every other DP worker is also microbatching
num_tokens_after_padding = None
(should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
vllm_config)
if not should_ubatch:
return (None, None)
# This doesn't actually pad the ubatch slices. It just initializes the
# split point to the padded value so that padding can be applied
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
num_tokens_unpadded)
# Note there's an assumption here that there's 1 token per request
ubatch_slices = [
UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
]
return (ubatch_slices, num_tokens_after_padding)
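# An illustrative sketch (made-up numbers, not called anywhere) of the split
# computed above for a pure-decode batch: 6 requests at one token each, with
# each ubatch padded to nominally hold 4 tokens. The first slice is [0, 4)
# and the second is [4, 6); the second ubatch is only padded out later, in
# pad_out_ubatch_slice after attention metadata creation.
def _example_decode_split(num_tokens_unpadded: int,
                          tokens_per_ubatch_padded: int) -> UBatchSlices:
    first = slice(0, tokens_per_ubatch_padded)
    second = slice(tokens_per_ubatch_padded, num_tokens_unpadded)
    # One token per request, so request and token slices coincide.
    return [UBatchSlice(first, first), UBatchSlice(second, second)]

# e.g. _example_decode_split(6, 4)
#   -> [UBatchSlice(slice(0, 4), slice(0, 4)),
#       UBatchSlice(slice(4, 6), slice(4, 6))]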

View File

@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing_extensions import TypeAlias
@dataclass
class UBatchSlice:
request_slice: slice
token_slice: slice
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int) -> bool:
return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
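# Worked example: with 3 real tokens per ubatch padded up to 8 per ubatch,
# the first (padded) ubatch alone covers all 6 real tokens, so the second
# ubatch would be empty and microbatching is aborted upstream:
#   is_second_ubatch_empty(3, 8) -> 8 >= 2 * 3 -> True
#   is_second_ubatch_empty(5, 8) -> 8 >= 2 * 5 -> False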

211
vllm/v1/worker/ubatching.py Normal file
View File

@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream
_THREAD_ID_TO_CONTEXT: dict[int, int] = {}
_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None]
class UBatchContext:
"""
Context manager for micro-batching synchronization using threading events.
"""
def __init__(self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
forward_context: ForwardContext,
ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
schedule: str = "default"):
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.forward_context = forward_context
self.ready_barrier = ready_barrier
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
self.gpu_comm_done_event = gpu_comm_done_event
self.gpu_compute_done_event = gpu_compute_done_event
self.schedule = schedule
self.recv_hook = None
def __enter__(self):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
_CURRENT_CONTEXTS[self.id] = self
self.ready_barrier.wait()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# Assume we start on the compute stream
assert current_stream() == self.compute_stream
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_CURRENT_CONTEXTS[self.id] = None
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
self.maybe_run_recv_hook()
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
torch.cuda.set_stream(self.current_stream)
return False
def _restore_context(self):
forward_context._forward_context = self.forward_context
torch.cuda.set_stream(self.current_stream)
def update_stream(self, stream):
self.current_stream = stream
torch.cuda.set_stream(self.current_stream)
def _signal_comm_done(self):
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
self.comm_stream.wait_event(self.gpu_compute_done_event)
def _wait_comm_done(self):
self.compute_stream.wait_event(self.gpu_comm_done_event)
def _cpu_yield(self):
# It is critical for correctness that only one thread is running
# at a time. These asserts just make sure that this is the only
# thread running before waking the other one up and going to sleep
assert forward_context._forward_context == self.forward_context
assert current_stream() == self.current_stream
assert not self.cpu_wait_event.is_set()
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
def switch_to_comm_sync(self):
self._signal_compute_done()
self.update_stream(self.comm_stream)
self._wait_comm_done()
def maybe_run_recv_hook(self):
if self.recv_hook is not None:
self.recv_hook()
self.recv_hook = None
def yield_(self):
self.current_stream = current_stream()
self._cpu_yield()
if self.current_stream != current_stream():
self.update_stream(self.current_stream)
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
self._signal_compute_done()
self._cpu_yield()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
self._signal_comm_done()
self._cpu_yield()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
self._wait_comm_done()
def dbo_enabled() -> bool:
return len(_THREAD_ID_TO_CONTEXT) > 0
def dbo_current_ubatch_id() -> int:
if len(_THREAD_ID_TO_CONTEXT) == 0:
return 0
return _THREAD_ID_TO_CONTEXT[threading.get_ident()]
def _register_ubatch_function(func):
def wrapper(*args, **kwargs):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
ctx = _CURRENT_CONTEXTS[ctx_idx]
func(ctx, *args, **kwargs)
return wrapper
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_maybe_run_recv_hook = _register_ubatch_function(
UBatchContext.maybe_run_recv_hook)
dbo_switch_to_comm_sync = _register_ubatch_function(
UBatchContext.switch_to_comm_sync)
def dbo_register_recv_hook(recv_hook):
if len(_THREAD_ID_TO_CONTEXT) > 0:
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
next_ctx.recv_hook = recv_hook
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext],
ready_barrier: threading.Barrier,
schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the per-microbatch UBatchContexts used to synchronize the two
    ubatch threads.
    """
    assert num_micro_batches == 2, \
        "microbatching has only been tested with 2 micro-batches"
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [
torch.cuda.Event() for _ in range(num_micro_batches)
]
gpu_compute_done_events = [
torch.cuda.Event() for _ in range(num_micro_batches)
]
assert len(forward_contexts) == 2
ctxs = []
for i in range(num_micro_batches):
ctx = UBatchContext(id=i,
compute_stream=compute_stream,
comm_stream=comm_stream,
forward_context=forward_contexts[i],
ready_barrier=ready_barrier,
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) %
num_micro_batches],
gpu_comm_done_event=gpu_comm_done_events[i],
gpu_compute_done_event=gpu_compute_done_events[i],
schedule=schedule)
ctxs.append(ctx)
return ctxs
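# A minimal sketch of how the dbo_* helpers above are meant to be called from
# inside a model layer, assuming a hypothetical layer whose forward alternates
# compute with collective communication; `dispatch`, `experts` and `combine`
# are placeholder callables (e.g. an all-to-all dispatch/combine pair). Each
# yield hands the GPU over to the other ubatch, so one ubatch's communication
# overlaps the other's compute; when DBO is disabled these helpers are no-ops.
def _example_overlapped_forward(hidden_states, dispatch, experts, combine):
    # Compute phase done: let the other ubatch compute while this thread
    # moves onto the communication stream.
    dbo_yield_and_switch_from_compute_to_comm()
    routed = dispatch(hidden_states)
    # Communication issued: switch back to compute and let the other
    # ubatch run its communication.
    dbo_yield_and_switch_from_comm_to_compute()
    out = experts(routed)
    dbo_yield_and_switch_from_compute_to_comm()
    combined = combine(out)
    dbo_yield_and_switch_from_comm_to_compute()
    return combined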

View File

@ -130,9 +130,17 @@ class MultiModalBudget:
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
metadata_builder: AttentionMetadataBuilder
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
def get_metadata_builder(self,
ubatch_id: Optional[int] = None
) -> AttentionMetadataBuilder:
if ubatch_id is None:
return self.metadata_builders[0]
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
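# Usage note (illustrative only): callers that are not ubatch-aware keep
# calling get_metadata_builder() and receive the first builder, while a
# dual-batch code path can select a builder per microbatch, e.g.
#   builder = attn_group.get_metadata_builder(ubatch_id)
# where ubatch_id is the index of the microbatch being built (0 or 1).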
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,