[Core/DBO][2/N] Dual-Batch Overlap add DeepEP High Throughput support and Prefill support (#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-09-23 12:02:10 -04:00 committed by GitHub
parent a903669e10
commit cc1dc7ed6d
19 changed files with 602 additions and 236 deletions

View File

@ -5,11 +5,12 @@ import pytest
import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
@pytest.fixture
@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert results[1].num_reqs == mid_point
assert results[1].num_actual_tokens == mid_point
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[
# Split in the middle of request 1
([32, 40], [8, 8], 12, 2, 1),
# Split inside the first request
([32, 40], [8, 8], 4, 1, 2),
],
)
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
expected_first_reqs,
expected_second_reqs):
"""Test splitting a prefill across ubatches"""
import numpy as np
device = torch.device("cpu")
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
common = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
qsl_np = common.query_start_loc_cpu.numpy()
num_tokens = common.num_actual_tokens
ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
assert len(ubatch_slices) == 2
first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
# Token counts match the split
assert first_meta.num_actual_tokens == split_point
assert second_meta.num_actual_tokens == num_tokens - split_point
# Number of requests per ubatch
assert first_meta.num_reqs == expected_first_reqs
assert second_meta.num_reqs == expected_second_reqs
# Identify which request is split and how many tokens are in the first chunk
split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
orig_q_lens = (common.query_start_loc_cpu[1:] -
common.query_start_loc_cpu[:-1])
# Check query length continuity: first-chunk + second-chunk == original qlen
# First ubatch last request query length
qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
first_meta.query_start_loc_cpu[-2])
# Second ubatch first request query length
qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
second_meta.query_start_loc_cpu[0])
assert qlen_first_last == tokens_in_first_chunk
assert qlen_first_last + qlen_second_first == int(
orig_q_lens[split_req_idx])
# Check seq_lens adjustments
# Context lengths per original request
context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
# First ubatch: last request's seq_len should be
# context + tokens_in_first_chunk
expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
assert int(first_meta.seq_lens[-1]) == expected_seqlen
# For full preceding requests in first ubatch, seq_lens should match
# originals
for i in range(first_meta.num_reqs - 1):
assert int(first_meta.seq_lens[i]) == seq_lens[i]
# Second ubatch: first request (continuation) seq_len should be full
# original
assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
# Any following full requests in second ubatch should match originals
for j in range(1, second_meta.num_reqs):
# Map to original request index
orig_idx = split_req_idx + j
assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]

View File

@ -532,9 +532,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
]
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@ -659,9 +658,8 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
]
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,

View File

@ -638,11 +638,13 @@ class VllmConfig:
if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \
"Microbatching currently only supports the deepep_low_latency "\
f"all2all backend. {a2a_backend} is not supported. To fix set "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\
"deepep_low_latency and install the DeepEP kerenls."
assert a2a_backend in \
["deepep_low_latency", "deepep_high_throughput"], \
"Microbatching currently only supports the deepep_low_latency and "\
f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
"supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
"variable to deepep_low_latency or deepep_high_throughput and "\
"install the DeepEP kernels."
if not self.instance_id:
self.instance_id = random_uuid()[:5]
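For reference, a minimal sketch (assumed usage, not part of this commit) of selecting one of the two supported backends before the engine starts:

import os

# Either value now satisfies the assert above when DBO is enabled.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"
# or: os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"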

View File

@ -139,12 +139,18 @@ class ParallelConfig:
"""Disable the custom all-reduce kernel and fall back to NCCL."""
enable_dbo: bool = False
"""Enable microbatching for the model executor."""
"""Enable dual batch overlap for the model executor."""
dbo_decode_token_threshold: int = 32
"""The threshold for microbatching. If the number of tokens in the
request is greater than this threshold, microbatching will be used.
Otherwise, the request will be processed in a single batch."""
"""The threshold for dual batch overlap for batches only containing decodes.
If the number of tokens in the request is greater than this threshold,
microbatching will be used. Otherwise, the request will be processed in a
single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from typing import Any, Optional
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_rdma_bytes = None
num_qps_per_rank = None
if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
else:
num_rdma_bytes = 0
@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
def set_num_sms(self, num_sms: int):
import deep_ep
# Right now the buffers are sized only for what the kernels were
# created with, so we can only reduce the number of SMs used
# but not increase it.
if num_sms > self.num_sms:
num_sms = self.num_sms
deep_ep.Buffer.set_num_sms(num_sms)
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
"""
@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
import deep_ep
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = num_local_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
return handle
# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
return 0

View File

@ -60,6 +60,12 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def set_num_sms(self, num_sms: int):
pass
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

View File

@ -330,6 +330,8 @@ class EngineArgs:
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = \
ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = \
ParallelConfig.dbo_prefill_token_threshold
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \
@ -698,6 +700,9 @@ class EngineArgs:
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"])
parallel_group.add_argument(
"--dbo-prefill-token-threshold",
**parallel_kwargs["dbo_prefill_token_threshold"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config",
@ -1316,6 +1321,7 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel,
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy,
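A minimal sketch (assumed usage; the model name is a placeholder) of driving the new knobs programmatically with the fields added above:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="placeholder/model",        # hypothetical model name
    enable_expert_parallel=True,
    enable_dbo=True,                  # --enable-dbo
    dbo_decode_token_threshold=32,    # --dbo-decode-token-threshold
    dbo_prefill_token_threshold=512,  # --dbo-prefill-token-threshold
)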

View File

@ -189,6 +189,8 @@ if TYPE_CHECKING:
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
VLLM_DBO_COMM_SMS: int = 20
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
@ -1392,6 +1394,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
# The size in MB of the buffers (NVL and RDMA) used by DeepEP
"VLLM_DEEPEP_BUFFER_SIZE_MB":
lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),
# The number of SMs to allocate for communication kernels when running DBO;
# the rest of the SMs on the device will be allocated to compute
"VLLM_DBO_COMM_SMS":
lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
# Valid values are container,code_interpreter,web_search_preview
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":

View File

@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
dbo_switch_to_compute, dbo_switch_to_compute_sync,
dbo_yield_and_switch_from_comm_to_compute,
dbo_yield_and_switch_from_compute_to_comm)
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.async_prepare = True
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
# requires. Under DBO microbatching we must track one handle per
# micro-batch to avoid races between threads.
self.handles = [None, None]
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
has_scales = token_scales is not None
# We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
(num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
(
token_data, expert_topk_ids, expert_topk_weights,
expert_num_tokens_per_expert_list, self.handle, event
expert_num_tokens_per_expert_list, handle, event
) = self.buffer.dispatch(
x=token_data,
handle=None,
@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
async_finish=self.async_prepare,
async_finish=self.async_prepare and not dbo_enabled(),
allocate_on_comm_stream=False)
# record the handle for this ubatch
a2a_idx = dbo_current_ubatch_id()
self.handles[a2a_idx] = handle
dbo_switch_to_compute_sync()
return lambda: self._receiver(
event,
has_scales,
@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if self.async_prepare:
if event.event is not None:
event.current_stream_wait()
if has_scales:
@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:
) -> mk.ReceiverType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q_scale = None
a1_post_scale = quant_config.a1_scale
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config))
quant_config=quant_config)
def prepare(
self,
@ -252,9 +267,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts,
expert_map, apply_router_weight_on_input,
quant_config)
return receiver()
@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_async: bool,
) -> Optional[Callable]:
assert self.handle is not None
a2a_idx = dbo_current_ubatch_id()
handle = self.handles[a2a_idx]
assert handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
dbo_yield_and_switch_from_compute_to_comm()
combined_x, _, event = self.buffer.combine(
x=fused_expert_output,
handle=self.handle,
handle=handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=do_async,
async_finish=do_async and not dbo_enabled(),
allocate_on_comm_stream=False)
dbo_switch_to_compute()
if do_async:
def _receiver():
if event.event is not None:
event.current_stream_wait()
dbo_switch_to_comm()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return lambda: _receiver()
# TODO(lucas): refactor the modular kernel so this will be
# handled there
dbo_yield_and_switch_from_comm_to_compute()
return _receiver
else:
# TODO(lucas): support this case with the refactored modular kernel
assert not dbo_enabled()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
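To summarize the choreography above, a conceptual sketch (hypothetical helper, not the class method itself) of how dispatch is bracketed by the ubatching primitives under DBO:

from vllm.v1.worker.ubatching import (
    dbo_current_ubatch_id, dbo_switch_to_compute_sync,
    dbo_yield_and_switch_from_compute_to_comm)

def dispatch_with_dbo(run_dispatch, handles):
    # run_dispatch: hypothetical callable wrapping buffer.dispatch(...) that
    # returns (dispatched_tokens, handle); handles has one slot per ubatch.
    # Yield first: the dispatch kernel blocks the CPU, so queue up the other
    # ubatch's compute before it launches, then move to the comm stream.
    dbo_yield_and_switch_from_compute_to_comm()
    tokens, handle = run_dispatch()
    handles[dbo_current_ubatch_id()] = handle  # per-ubatch handle, no races
    # Hop back onto the compute stream and sync before running the experts.
    dbo_switch_to_compute_sync()
    return tokens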

View File

@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
do_async: bool,
) -> Optional[Callable]:
) -> tuple[Callable, Callable]:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=do_recv_hook,
out=output)
return recv_hook
return recv_hook, lambda: None
def finalize_async(
self,
@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
recv_hook = self._finalize(
) -> tuple[Callable, Callable]:
return self._finalize(
output,
fused_expert_output,
topk_weights,
@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
weight_and_reduce_impl,
do_async=True,
)
assert recv_hook is not None
return recv_hook
def finalize(
self,

View File

@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
#
@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, ReceiverType]:
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `prepare`, e.g.
Returns a callback, or a (hook, callback) pair, that when invoked waits for
results from other workers and has the same return signature as
`prepare`. If a hook is returned, it is a lighter-weight check that
the recv is complete without doing extra work (used by DBO; will be
refactored in the near future).
receiver = obj.prepare_async(...)
e.g.

    ret = obj.prepare_async(...)
    if isinstance(ret, tuple):
        hook, receiver = ret
    else:
        hook, receiver = None, ret
    if hook is not None:
        hook()
    a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
is equivalent to:
@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> Callable:
) -> Union[tuple[Callable, Callable], Callable]:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers.
@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `finalize`, e.g.
Returns a callback, or a (hook, callback) pair, that when invoked waits for
results from other workers and has the same return signature as
`finalize`. If a hook is returned, it is a lighter-weight check that
the recv is complete without doing extra work (used by DBO; will be
refactored in the near future).
receiver = obj.finalize_async(output, ...)

    ret = obj.finalize_async(output, ...)
    ... output not valid yet ...
    if isinstance(ret, tuple):
        hook, receiver = ret
        hook()
    else:
        receiver = ret
    receiver()
    ... output valid here ...
@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
layer due to any layer specific state that may be used by the component
objects.
"""
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer()
workspace2_buffer = SharedResizableBuffer()
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocations.
#
# We have two sets of buffers to support dual batch overlap (DBO), where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contamination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__(
self,
@ -647,12 +679,16 @@ class FusedMoEModularKernel(torch.nn.Module):
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = self.workspace13_buffer.get(workspace13_shape,
workspace13 = buffers.workspace13.get(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = self.workspace2_buffer.get(workspace2_shape,
workspace2 = buffers.workspace2.get(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
@ -733,7 +769,9 @@ class FusedMoEModularKernel(torch.nn.Module):
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = self.fused_out_buffer.get(fused_out_shape,
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
fused_out = buffers.fused_out.get(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# TODO(lucas): enable in follow-up
assert not dbo_enabled()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
else:
# Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
prepare_ret = self.prepare_finalize.prepare_async(
a1,
topk_weights,
topk_ids,
@ -893,12 +932,20 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config,
)
# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
# TODO(lucas): refactor this in the alternative schedules followup
# Currently we unpack either a (hook, receiver) pair or just a
# receiver (see finalize_async docstring).
hook, receiver = prepare_ret \
if isinstance(prepare_ret, tuple) else (None, prepare_ret)
if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
if not dbo_enabled():
else:
hook()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
else:
recv_hook = self.prepare_finalize.finalize_async(
finalize_ret = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
assert recv_hook is not None
dbo_register_recv_hook(recv_hook)
# TODO(lucas): refactor this in the alternative schedules followup
# Currently we unpack either a (hook, receiver) pair or just a
# receiver (see finalize_async docstring).
hook, receiver = finalize_ret \
if isinstance(finalize_ret, tuple) else (None, finalize_ret)
if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
if not dbo_enabled():
recv_hook()
else:
hook()
receiver()
if self.shared_experts is None:
return output
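The same unpacking appears in both the prepare and finalize paths above; a compact restatement of the pattern as a hedged sketch (hypothetical helper):

from vllm.v1.worker.ubatching import (dbo_enabled, dbo_register_recv_hook,
                                      dbo_yield)

def run_async_a2a(async_op, *args, **kwargs):
    # async_op is prepare_async or finalize_async; it may return a bare
    # receiver or a (hook, receiver) pair (see the docstrings above).
    ret = async_op(*args, **kwargs)
    hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
    if hook is not None:
        if dbo_enabled():
            # Under DBO the hook is run via dbo_maybe_run_recv_hook() on the
            # other ubatch's schedule instead of being called here directly.
            dbo_register_recv_hook(hook)
            dbo_yield()
        else:
            hook()
    return receiver()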

View File

@ -107,19 +107,57 @@ def _make_metadata_with_slice(
the requests included in ubatch_slice
"""
assert not ubatch_slice.is_empty(), (
f"Ubatch slice {ubatch_slice} is empty")
request_slice = ubatch_slice.request_slice
token_slice = ubatch_slice.token_slice
start_locs = attn_metadata.query_start_loc_cpu
first_req = request_slice.start
first_tok = token_slice.start
last_req = request_slice.stop - 1
last_tok = token_slice.stop - 1
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
"Token slice start outside of first request"
assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \
"Token slice end outside of last request"
# If the "middle" request has tokens in both ubatches, we have to split it.
# If ubatch_slice is the first ubatch then we will be splitting the last
# request. If it's the second microbatch, then we will be splitting the
# first request
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
request_slice)
assert len(query_start_loc) >= 2, (
f"query_start_loc must have at least 2 elements, "
f"got {len(query_start_loc)}")
query_start_loc_cpu = slice_query_start_locs(
attn_metadata.query_start_loc_cpu, request_slice)
if splits_first_request:
tokens_skipped = first_tok - start_locs[first_req]
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
seq_lens_cpu[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
request_slice]
@ -167,6 +205,7 @@ def split_attn_metadata(
for ubatch_slice in ubatch_slices:
results.append(
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
return results
@ -696,7 +735,6 @@ def split_decodes_and_prefills(
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes

View File

@ -30,7 +30,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
@ -192,9 +191,8 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
@ -330,7 +328,7 @@ class EagleProposer:
# Rebuild attention metadata
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
@ -538,9 +536,8 @@ class EagleProposer:
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
ubatch_id = dbo_current_ubatch_id()
tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
self.runner.attn_groups[0][0].get_metadata_builder()
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)

View File

@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
ubatch_split)
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded)
ubatch_slices, num_tokens_after_padding = \
ubatch_split(max_num_scheduled_tokens,
ubatch_split(num_scheduled_tokens,
num_tokens_unpadded,
num_tokens_padded,
self.vllm_config)
@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
ubatch_slices, common_attn_metadata)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (attn_group.get_metadata_builder(
ubatch_id=ubid).build(
common_prefix_len=common_prefix_len,
@ -2182,9 +2182,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) = self._preprocess(scheduler_output, intermediate_tensors,
ubatch_slices, num_tokens_after_padding)
if ubatch_slices is not None:
num_input_tokens = num_input_tokens // 2
uniform_decode = (max_query_len
== self.uniform_decode_query_len) and (
num_scheduled_tokens
@ -2194,6 +2191,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# This is currently needed to get around the assert in DPMetadata,
# which wants `num_tokens_across_dp` to align with `num_tokens`
if ubatch_slices is not None:
num_input_tokens = ubatch_slices[0].num_tokens
# Run the model.
# Use persistent buffers for CUDA graphs.
with (set_forward_context(
@ -2821,7 +2823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
allow_microbatching: bool = False,
allow_microbatching: bool = True,
skip_eplb: bool = False,
is_profile: bool = False,
create_mixed_batch: bool = False,
@ -2847,32 +2849,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
"""
ubatch_enabled = self.parallel_config.enable_dbo
num_tokens_across_dp = None
num_pad = 0
should_ubatch = False
if ubatch_enabled:
should_ubatch = num_tokens >= \
self.parallel_config.dbo_decode_token_threshold and \
allow_microbatching
(should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
num_tokens, num_tokens, should_ubatch, self.vllm_config)
# Currently the dummy run should only be ubatching during
# cuda graph capture, meaning all DP ranks should already
# have the same batch size
if num_tokens_across_dp is not None:
assert int(num_tokens_across_dp[0]) == num_tokens // 2
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
if not should_ubatch:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
@ -2888,10 +2868,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
if allow_microbatching:
assert self.uniform_decode_query_len == 1
assert uniform_decode is True
assert max_query_len == 1
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@ -2930,20 +2906,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
ubatch_slices = None
num_tokens_after_padding = None
# We currently only microbatch if the number of tokens is
# over a certain threshold.
if should_ubatch:
# We only support decode-only cudagraphs
assert num_reqs == num_tokens
assert num_tokens % 2 == 0
ubatch_slices = [
UBatchSlice(slice(0, num_reqs // 2), slice(0,
num_tokens // 2)),
UBatchSlice(slice(num_reqs // 2, num_reqs),
slice(num_tokens // 2, num_tokens))
]
if self.parallel_config.enable_dbo and allow_microbatching:
ubatch_slices, num_tokens_after_padding = ubatch_split(
num_scheduled_tokens,
total_num_scheduled_tokens,
total_num_scheduled_tokens,
self.vllm_config,
)
# If we failed to microbatch, currently need to resynchronize
# TODO(lucas,sage): we should be able to avoid this second sync by
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
# a single `coordinate_batch_across_dp` function.
if num_tokens_after_padding is None:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens_after_padding = num_tokens + num_pad
else:
num_tokens_across_dp = num_tokens_after_padding
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
attn_metadata: Optional[PerLayerAttnMetadata] = None
@ -2966,6 +2953,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
cum_num_tokens, _ = self._get_cumsum_and_arange(
num_scheduled_tokens)
self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
common_attn_metadata = CommonAttentionMetadata(
@ -3060,7 +3052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens=num_tokens_after_padding,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
@ -3395,38 +3387,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
cudagraph_runtime_mode.name))
enable_dbo = self.parallel_config.enable_dbo
# DBO Only supports running Full cudagraphs with uniform
# decode lengths
if enable_dbo and uniform_decode:
for num_tokens in compilation_cases:
# If the number of tokens is greater than the microbatching
# threshold, don't generate a microbatched cudagraph
if (num_tokens
< self.parallel_config.dbo_decode_token_threshold):
continue
# Warmup
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=True,
allow_microbatching=True,
skip_eplb=True)
# Graph Capture
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True,
allow_microbatching=True,
skip_eplb=True)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# We currently only capture ubatched graphs when it's a FULL
# cudagraph and for uniform decode batches.
capture_ubatched_graph = self.parallel_config.enable_dbo \
and cudagraph_runtime_mode == CUDAGraphMode.FULL \
and uniform_decode \
and check_ubatch_thresholds(
config=self.vllm_config.parallel_config,
num_tokens=num_tokens,
uniform_decode=uniform_decode,
)
# Currently we capture both microbatched and non-microbatched
# graphs when capture_ubatched_graph is True; this is because
# occasionally we will be forced out of microbatching due to other
# DP ranks not microbatching (usually caused by an empty second
# microbatch; once we resolve this, we can remove the
# non-microbatched graph capture).
allow_microbatching_options = [True, False] if \
capture_ubatched_graph else [False]
for allow_microbatching in allow_microbatching_options:
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# whether we want to warm up attention or not. This is
@ -3438,11 +3423,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False)
self.maybe_remove_all_loras(self.lora_config)
@ -3500,24 +3487,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_groups: list[AttentionGroup] = []
for (attn_backend,
kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = []
attn_metadata_builders.append(attn_backend.get_builder_cls()(
kv_cache_spec,
attn_group = AttentionGroup.create_with_metadata_builders(
attn_backend,
layer_names,
kv_cache_spec,
self.vllm_config,
self.device,
))
if self.parallel_config.enable_dbo:
attn_metadata_builders.append(
attn_backend.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
))
attn_group = AttentionGroup(attn_backend,
attn_metadata_builders,
layer_names, kv_cache_spec)
num_metadata_builders=1
if not self.parallel_config.enable_dbo else 2,
)
attn_groups.append(attn_group)
return attn_groups

View File

@ -1,25 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import threading
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import has_deep_gemm
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@dataclasses.dataclass
@dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
@ -29,13 +32,55 @@ class UbatchMetadata:
num_tokens: int
@dataclasses.dataclass
@dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
ubatch_metadata: UbatchMetadata
outputs: Optional[Any] = None
class SMControlContextManager:
def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
set_compute_sms: Callable[[int], None]):
"""
Context manager for controlling SM (Streaming Multiprocessor)
allocation. Upon entering the context, it sets the number of SMs
allocated for communication and computation to comm_sms and
total_sms - comm_sms respectively. Upon exiting, it restores the
allocation to use all available SMs (i.e. total_sms).
Args:
comm_sms (int): The number of SMs to allocate for communication.
(The remainder will be used for computation.)
set_comm_sms (Callable[[int], None]):
A function that sets the number of SMs for communication.
set_compute_sms (Callable[[int], None]):
A function that sets the number of SMs for computation.
"""
assert current_platform.is_cuda(), \
"SM control is currently only supported on CUDA"
props = torch.cuda.get_device_properties(torch.cuda.current_device())
total_sms = props.multi_processor_count
assert comm_sms < total_sms
self.total_sms = total_sms
self.compute_sms = total_sms - comm_sms
self.comm_sms = comm_sms
self.set_comm_sms = set_comm_sms
self.set_compute_sms = set_compute_sms
def __enter__(self):
self.set_comm_sms(self.comm_sms)
self.set_compute_sms(self.compute_sms)
def __exit__(self, exc_type, exc_value, traceback):
self.set_comm_sms(self.total_sms)
self.set_compute_sms(self.total_sms)
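A small usage sketch of the context manager (requires a CUDA device; the list-appending setters are illustrative stand-ins for all2all_manager.set_num_sms and deep_gemm.set_num_sms):

from vllm.v1.worker.gpu_ubatch_wrapper import SMControlContextManager

comm_history, compute_history = [], []
sm_ctx = SMControlContextManager(comm_sms=20,
                                 set_comm_sms=comm_history.append,
                                 set_compute_sms=compute_history.append)
with sm_ctx:
    pass  # the two micro-batches would run here
# Inside: comm kernels get 20 SMs, compute gets total_sms - 20.
# On exit both are restored to the full device (total_sms).
assert comm_history == [20, sm_ctx.total_sms]
assert compute_history == [sm_ctx.total_sms - 20, sm_ctx.total_sms]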
class UBatchWrapper:
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
@ -56,6 +101,35 @@ class UBatchWrapper:
runnable, vllm_config, runtime_mode=runtime_mode)
self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config)
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
# Currently only DeepEP high-throughput supports SM control, so this
# only affects that case.
all2all_manager = get_ep_group(
).device_communicator.all2all_manager
if all2all_manager.max_sms_used() is not None:
comm_sms = min(comm_sms, all2all_manager.max_sms_used())
if comm_sms > 0:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
# TODO(lucas): support other kernels besides DeepGEMM
set_compute_sms = lambda sms: None
if has_deep_gemm() and comm_sms > 0:
import deep_gemm as dg
set_compute_sms = lambda sms: dg.set_num_sms(sms)
return SMControlContextManager(comm_sms=comm_sms,
set_comm_sms=set_comm_sms,
set_compute_sms=set_compute_sms)
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
@ -282,7 +356,7 @@ class UBatchWrapper:
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model)
elif num_tokens in self.cudagraphs:
cudagraph_metadata = self.cudagraphs[num_tokens]
@ -300,4 +374,5 @@ class UBatchWrapper:
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE)
with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)

View File

@ -3,9 +3,10 @@
from typing import Optional
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens(
dp_size, dp_rank)
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
uniform_decode: bool) -> bool:
if not config.enable_dbo:
return False
if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold
else:
return num_tokens >= config.dbo_prefill_token_threshold
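A quick sanity check of the gating logic (SimpleNamespace stands in for ParallelConfig here, with the default thresholds added in this PR):

from types import SimpleNamespace

from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds

cfg = SimpleNamespace(enable_dbo=True,
                      dbo_decode_token_threshold=32,
                      dbo_prefill_token_threshold=512)
# 64 tokens: enough for a decode-only batch, not for one containing prefills.
assert check_ubatch_thresholds(cfg, num_tokens=64, uniform_decode=True)
assert not check_ubatch_thresholds(cfg, num_tokens=64, uniform_decode=False)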
def get_dp_padding_ubatch(
num_tokens_unpadded: int, num_tokens_padded: int,
should_attempt_ubatching: bool,
@ -95,9 +106,37 @@ def get_dp_padding_ubatch(
dtype=torch.int32)
return should_ubatch, num_tokens_after_padding
def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
-> UBatchSlices:
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point)
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
# Determine request slices using exclusive stop semantics
# First ubatch includes requests whose tokens overlap [0, split_point)
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left"))
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point
# or the request starting exactly at split_point (if on boundary)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
second_ubatch_req_slice = slice(second_ubatch_req_start,
len(cu_num_tokens) - 1)
return [
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
]
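A worked example of the searchsorted arithmetic, using the same numbers as the new test above (query_lens [8, 8], split point 12):

import numpy as np

num_scheduled_tokens = np.array([8, 8], dtype=np.int32)
split_point = 12
cu_num_tokens = np.zeros(3, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])  # [0, 8, 16]

first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))         # 2
second_req_start = int(np.searchsorted(cu_num_tokens, split_point, side="right") - 1)  # 1

# ubatch 0: tokens [0, 12) over requests slice(0, 2); ubatch 1: tokens [12, 16)
# over requests slice(1, 2), so request 1 spans both micro-batches.
assert (first_req_stop, second_req_start) == (2, 1)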
def ubatch_split(
max_num_scheduled_tokens: int,
num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int,
num_tokens_padded: int,
vllm_config: VllmConfig,
@ -122,17 +161,20 @@ def ubatch_split(
return (None, None)
# Check preconditions for microbatching
should_attempt_ubatching = \
parallel_config.enable_dbo and \
num_tokens_unpadded >= \
parallel_config.dbo_decode_token_threshold \
and max_num_scheduled_tokens == 1
should_attempt_ubatching = check_ubatch_thresholds(
parallel_config,
num_tokens_unpadded,
vllm_config,
)
# Don't microbatch unless every other DP worker is also microbatching
num_tokens_after_padding = None
(should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
vllm_config)
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
vllm_config,
)
if not should_ubatch:
return (None, None)
@ -141,15 +183,9 @@ def ubatch_split(
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
num_tokens_unpadded)
token_split_point = int(num_tokens_after_padding[0].item())
# Note there's an assumption here that there's 1 token per request
ubatch_slices = [
UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
]
ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
token_split_point)
return (ubatch_slices, num_tokens_after_padding)

View File

@ -10,6 +10,14 @@ class UBatchSlice:
request_slice: slice
token_slice: slice
def is_empty(self) -> bool:
return self.request_slice.start == self.request_slice.stop \
or self.token_slice.start == self.token_slice.stop
@property
def num_tokens(self) -> int:
return self.token_slice.stop - self.token_slice.start
UBatchSlices: TypeAlias = list[UBatchSlice]
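A tiny example of the new helpers (import path taken from the model-runner import above, assuming UBatchSlice stays a plain dataclass):

from vllm.v1.worker.ubatch_utils import UBatchSlice

s = UBatchSlice(request_slice=slice(1, 2), token_slice=slice(12, 16))
assert not s.is_empty()
assert s.num_tokens == 4  # tokens [12, 16)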

View File

@ -51,8 +51,8 @@ class UBatchContext:
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# Assume we start on the compute stream
assert current_stream() == self.compute_stream
# Assume we want to start on the compute stream
self.update_stream(self.compute_stream)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
@ -62,16 +62,14 @@ class UBatchContext:
self.maybe_run_recv_hook()
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
torch.cuda.set_stream(self.current_stream)
return False
def _restore_context(self):
forward_context._forward_context = self.forward_context
torch.cuda.set_stream(self.current_stream)
def update_stream(self, stream):
self.current_stream = stream
if current_stream() != self.current_stream:
torch.cuda.set_stream(self.current_stream)
def _signal_comm_done(self):
@ -99,9 +97,20 @@ class UBatchContext:
self.cpu_wait_event.clear()
self._restore_context()
def switch_to_comm(self):
self.update_stream(self.comm_stream)
def switch_to_compute(self):
self.update_stream(self.compute_stream)
def switch_to_comm_sync(self):
self._signal_compute_done()
self.update_stream(self.comm_stream)
self._wait_compute_done()
def switch_to_compute_sync(self):
self._signal_comm_done()
self.update_stream(self.compute_stream)
self._wait_comm_done()
def maybe_run_recv_hook(self):
@ -112,7 +121,6 @@ class UBatchContext:
def yield_(self):
self.current_stream = current_stream()
self._cpu_yield()
if self.current_stream != current_stream():
self.update_stream(self.current_stream)
def yield_and_switch_from_compute_to_comm(self):
@ -153,15 +161,20 @@ def _register_ubatch_function(func):
return wrapper
dbo_maybe_run_recv_hook = _register_ubatch_function(
UBatchContext.maybe_run_recv_hook)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_maybe_run_recv_hook = _register_ubatch_function(
UBatchContext.maybe_run_recv_hook)
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
dbo_switch_to_compute = _register_ubatch_function(
UBatchContext.switch_to_compute)
dbo_switch_to_comm_sync = _register_ubatch_function(
UBatchContext.switch_to_comm_sync)
dbo_switch_to_compute_sync = _register_ubatch_function(
UBatchContext.switch_to_compute_sync)
def dbo_register_recv_hook(recv_hook):

View File

@ -130,15 +130,32 @@ class MultiModalBudget:
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch,
# so that if they use internal persistent buffers for cudagraphs they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
@staticmethod
def create_with_metadata_builders(
backend: type[AttentionBackend],
layer_names: list[str],
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
num_metadata_builders: int = 1,
) -> 'AttentionGroup':
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
device)
for _ in range(num_metadata_builders)
]
return AttentionGroup(backend, metadata_builders, layer_names,
kv_cache_spec)
def get_metadata_builder(self,
ubatch_id: Optional[int] = None
) -> AttentionMetadataBuilder:
if ubatch_id is None:
return self.metadata_builders[0]
ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
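Finally, a short sketch of how the per-ubatch builders are intended to be used (mirroring the model-runner call site above; attn_group is assumed to be an AttentionGroup created with num_metadata_builders=2 when DBO is enabled):

def build_per_ubatch_metadata(attn_group, common_attn_metadata_list,
                              common_prefix_len):
    # One builder per micro-batch, so any internal persistent (cudagraph)
    # buffers never alias across ubatches.
    return [
        attn_group.get_metadata_builder(ubatch_id=ubid).build(
            common_prefix_len=common_prefix_len,
            common_attn_metadata=common_attn_metadata)
        for ubid, common_attn_metadata in enumerate(common_attn_metadata_list)
    ]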