[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
08369289af
commit
567939953b
@@ -87,6 +87,11 @@ def parse_args():
default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
)
parser.add_argument(
"--enable-dbo",
action="store_true",
help=("Enable microbatched execution"),
)
parser.add_argument(
"--compilation-config",
type=int,
@@ -113,6 +118,7 @@ def main(
max_model_len,
compilation_config,
gpu_memory_utilization,
enable_dbo,
quantization,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -167,6 +173,7 @@ def main(
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enable_dbo=enable_dbo,
quantization=quantization,
compilation_config=compilation_config,
)
@@ -227,6 +234,7 @@ if __name__ == "__main__":
args.max_model_len,
args.compilation_config,
args.gpu_memory_utilization,
args.enable_dbo,
args.quantization,
),
)
@@ -6,7 +6,7 @@ import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UbatchSlice,
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
@@ -106,7 +106,7 @@ def mixed_small_metadata():
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
"""Test slicing decode batch metadata"""
# Split first request only
ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))

result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)

@@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):

def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
"""Test slicing mixed batch metadata"""
ubatch_slice = UbatchSlice(slice(1, 3),
ubatch_slice = UBatchSlice(slice(1, 3),
slice(1, 7))  # Requests 1-3, tokens 1-7

result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
@@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
num_tokens = large_decode_metadata.num_reqs
mid_point = num_tokens // 2
ubatch_slices = [
UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
num_tokens)),
]
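The tests above split a batch into two UBatchSlice halves. As a standalone illustration of that slicing idea (not part of this commit), here is a minimal sketch; the real class lives in vllm.v1.worker.ubatch_utils and, as the removed UbatchSlice definition further down in this diff shows, carries a request_slice and a token_slice.

# Illustrative sketch only: a standalone mimic of the request/token slicing
# that the tests above exercise. The real class is
# vllm.v1.worker.ubatch_utils.UBatchSlice.
from dataclasses import dataclass


@dataclass
class UBatchSliceSketch:
    request_slice: slice  # which requests fall into this microbatch
    token_slice: slice    # which tokens (flattened across requests) fall into it


def split_decode_batch(num_reqs: int, num_tokens: int) -> list[UBatchSliceSketch]:
    """Split a uniform decode batch (one token per request) into two halves."""
    mid = num_tokens // 2
    return [
        UBatchSliceSketch(slice(0, mid), slice(0, mid)),
        UBatchSliceSketch(slice(mid, num_reqs), slice(mid, num_tokens)),
    ]


if __name__ == "__main__":
    print(split_decode_batch(8, 8))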
@@ -365,7 +365,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]

result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@@ -489,7 +491,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]

# Setup inputs for the proposer.
target_token_ids = torch.randint(0,
@@ -2848,6 +2848,14 @@ class VllmConfig:
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \
"Microbatching currently only supports the deepep_low_latency "\
f"all2all backend. {a2a_backend} is not supported. To fix set "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\
"deepep_low_latency and install the DeepEP kernels."

if not self.instance_id:
self.instance_id = random_uuid()[:5]
@@ -137,6 +137,14 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""

enable_dbo: bool = False
"""Enable microbatching for the model executor."""

dbo_decode_token_threshold: int = 32
"""The threshold for microbatching. If the number of tokens in the
request is greater than this threshold, microbatching will be used.
Otherwise, the request will be processed in a single batch."""

ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
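The new ParallelConfig fields above are what the --enable-dbo and --dbo-decode-token-threshold switches map onto. A minimal usage sketch, not part of the commit: the model name is a placeholder and the exact argument surface may evolve, but enable_dbo is passed straight to LLM() in this commit's example script, and the VllmConfig check above requires the deepep_low_latency all2all backend.

# Minimal usage sketch based on the example script and the VllmConfig check in
# this commit. DBO currently requires the deepep_low_latency all2all backend.
import os

from vllm import LLM

os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MoE model for illustration
    enable_expert_parallel=True,
    enable_dbo=True,  # maps to ParallelConfig.enable_dbo
    # Microbatching kicks in only when a decode batch has at least
    # dbo_decode_token_threshold tokens (default 32, see above).
)

On the command line the same switches are exposed as --enable-dbo and --dbo-decode-token-threshold via the EngineArgs changes below.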
@@ -251,9 +251,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
@@ -327,6 +327,9 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = \
ParallelConfig.dbo_decode_token_threshold
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \
@@ -695,6 +698,11 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument("--enable-dbo",
**parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config",
@@ -1339,6 +1347,8 @@ class EngineArgs:
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy,
@@ -14,6 +14,7 @@ import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()

@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int, dp_size: int,
dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
"""
1. Decides if each DP rank is going to microbatch. Either all ranks
run with microbatching or none of them do. If this function decides
not to run with microbatching, it will "abort", meaning that no padding
information will be returned to the caller. It will return (False, None).

2. Determines the total number of tokens that each rank will run.
All ranks will be padded out so that they run with the same number
of tokens.

Returns: tuple[
should_ubatch: Are all DP ranks going to microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
None if should_ubatch is False
]
"""

device = current_platform.device_type
tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0

from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(tensor, group=get_dp_group().device_group)

result: bool = bool(torch.all(tensor[2] == 1).item())
if not result:
return result, None

orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :]

orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
padded_max_num_tokens)
return False, None
return result, padded_num_tokens_tensor.cpu()

@staticmethod
def make(
parallel_config: ParallelConfig,
@@ -119,14 +167,15 @@ class DPMetadata:

# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None
or num_tokens_across_dp[dp_rank] == batchsize)
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
if num_tokens_across_dp is None:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)

@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
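should_ubatch_across_dp above is an all-or-nothing agreement across DP ranks. A dependency-free sketch of the decision it makes after the all_reduce (illustrative only; the real emptiness test is vllm.v1.worker.ubatch_utils.is_second_ubatch_empty, so it is passed in here as a callable rather than guessed):

# Sketch of the agreement protocol: microbatching proceeds only if every DP
# rank voted for it and the padding would not leave the second microbatch
# empty on any rank.
from typing import Callable, Optional


def should_ubatch_sketch(
    per_rank: list[tuple[int, int, bool]],
    second_ubatch_empty: Callable[[int, int], bool],
) -> tuple[bool, Optional[list[int]]]:
    """per_rank holds (orig_tokens_per_ubatch, padded_tokens_per_ubatch, vote)."""
    if not all(vote for _, _, vote in per_rank):
        return False, None
    orig_min = min(orig for orig, _, _ in per_rank)
    padded_max = max(padded for _, padded, _ in per_rank)
    if second_ubatch_empty(orig_min, padded_max):
        return False, None
    return True, [padded for _, padded, _ in per_rank]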
@@ -179,9 +228,12 @@ class ForwardContext:
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
set dynamically for each forward pass
Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch.
Set dynamically for each forward pass
"""
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
list[dict[str, "AttentionMetadata"]]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int  # set dynamically for each forward pass
# set dynamically for each forward pass
@@ -191,6 +243,8 @@ class ForwardContext:
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None

ubatch_slices: Optional[UBatchSlices] = None

def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
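With DBO, ForwardContext.attn_metadata may now be a two-element list of per-layer dicts, one per microbatch. A hedged sketch (not from this commit) of how a consumer could resolve the metadata for a given layer, using the dbo_current_ubatch_id helper this commit introduces elsewhere:

# Hedged sketch: handling the new list-of-dicts shape of attn_metadata.
from vllm.v1.worker.ubatching import dbo_current_ubatch_id


def metadata_for_layer(attn_metadata, layer_name: str):
    if isinstance(attn_metadata, list):
        # DBO: one dict per microbatch, indexed by the running ubatch.
        return attn_metadata[dbo_current_ubatch_id()][layer_name]
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata  # v0: a single AttentionMetadata object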
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
return _forward_context


def create_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
dp_metadata: Optional[DPMetadata] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
ubatch_slices: Optional[UBatchSlices] = None):
return ForwardContext(no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices)


@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
"""A context manager that overrides the current forward context.
This is used to override the forward context for a specific
forward pass.
"""
global _forward_context
prev_context = _forward_context
_forward_context = forward_context
try:
yield
finally:
_forward_context = prev_context


@contextmanager
def set_forward_context(
attn_metadata: Any,
@@ -216,7 +303,8 @@ def set_forward_context(
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
ubatch_slices: Optional[UBatchSlices] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@@ -225,6 +313,7 @@ def set_forward_context(
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()

dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None):
@@ -232,20 +321,14 @@ def set_forward_context(
attn_metadata, num_tokens or 0,
num_tokens_across_dp)

global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
)
forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, dp_metadata,
cudagraph_runtime_mode,
batch_descriptor, ubatch_slices)

try:
yield
with override_forward_context(forward_context):
yield
finally:
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
@@ -282,5 +365,3 @@ def set_forward_context(
logger.info(("Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"),
forward_stats)

_forward_context = prev_context
@@ -191,7 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> Callable:
) -> tuple[Callable, mk.ReceiverType]:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -217,13 +217,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q_scale = None
a1_post_scale = a1_scale

return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config))

def prepare(
self,
@@ -237,10 +238,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids, num_experts,
expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()

def finalize(
@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)

# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -55,7 +58,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
self.handles: list[Optional[tuple]] = [None, None]
self.num_dispatchers_ = num_dispatchers

def num_dispatchers(self) -> int:
@@ -123,13 +126,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType:
) -> tuple[Callable, mk.ReceiverType]:

hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
f"{self.SUPPORTED_HIDDEN_SIZES}")

a2a_idx = dbo_current_ubatch_id()

if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
@@ -148,7 +153,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype)

# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
expert_x, expert_num_tokens, handle, _, hook= \
self.buffer.low_latency_dispatch(a1,
topk_ids,
self.max_tokens_per_rank,
@@ -156,21 +161,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=True)
self.handles[a2a_idx] = handle

return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config))

def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()

expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
@@ -192,10 +195,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
hook()
return receiver()

def finalize(
@@ -210,7 +215,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
assert self.handle is not None

a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled()
handle = self.handles[a2a_idx]
assert handle is not None

combine_topk_weights = topk_weights
if apply_router_weight_on_input:
@@ -218,12 +227,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)

# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
handle,
async_finish=False,
zero_copy=False,
return_recv_hook=False,
return_recv_hook=do_recv_hook,
out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook)
dbo_yield()
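The prepare_async contract changes here from returning a single receiver to returning a (hook, receiver) pair: a synchronous caller runs hook() and then receiver(), while a DBO-aware caller registers the hook with the microbatch context and yields so the other microbatch can compute while the all2all completes. A toy illustration of that two-phase pattern, with no vLLM or DeepEP dependency:

# Toy two-phase "prepare" mirroring the (hook, receiver) split above.
from typing import Any, Callable


def prepare_async_toy(payload: list[int]) -> tuple[Callable[[], None], Callable[[], Any]]:
    mailbox: dict[str, Any] = {}

    def hook() -> None:
        # Stands in for the recv half of the all2all (e.g. a DeepEP recv hook).
        mailbox["received"] = [x * 2 for x in payload]

    def receiver() -> Any:
        # Post-processing that must run only after the data has landed.
        return sum(mailbox["received"])

    return hook, receiver


# Synchronous path, mirroring DeepEPLLPrepareAndFinalize.prepare():
hook, receiver = prepare_async_toy([1, 2, 3])
hook()
print(receiver())  # 12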
@@ -38,6 +38,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@@ -992,16 +993,28 @@ class FusedMoE(CustomOp):
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
if vllm_config.parallel_config.enable_dbo:
self.batched_hidden_states = torch.zeros(
(2, moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())

# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(2, moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
else:
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())

# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())

@property
def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -1708,14 +1721,29 @@ class FusedMoE(CustomOp):
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]

assert (self.batched_hidden_states.size(0)  # type: ignore
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[
batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[
batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits

assert (batched_hidden_states.size(0)  # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0)  # type: ignore
assert (batched_router_logits.size(0)  # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :]  # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :]  # type: ignore
staged_hidden_states = batched_hidden_states[:
chunk_size, :]  # type: ignore
staged_router_logits = batched_router_logits[:
chunk_size, :]  # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
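When DBO is enabled, the staging buffers above gain a leading dimension of 2, one slot per microbatch, and the active slot is selected with dbo_current_ubatch_id(). A standalone sketch of that double-buffering idea (shapes are made up, not taken from the commit):

# Standalone sketch: one staging slot per microbatch, indexed by ubatch id.
import torch

MAX_NUM_TOKENS, HIDDEN_SIZE = 256, 1024
staging = torch.zeros(2, MAX_NUM_TOKENS, HIDDEN_SIZE)  # one slot per microbatch


def stage_chunk(chunk: torch.Tensor, ubatch_id: int) -> torch.Tensor:
    """Copy a chunk into the staging slot owned by the given microbatch."""
    slot = staging[ubatch_id, :chunk.shape[0]]
    slot.copy_(chunk)
    return slot


out = stage_chunk(torch.randn(8, HIDDEN_SIZE), ubatch_id=1)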
@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook, dbo_yield)
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> ReceiverType:
|
||||
) -> tuple[Callable, ReceiverType]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
return None
|
||||
|
||||
|
||||
class SharedResizableBuffer:
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = None
|
||||
|
||||
def get(self, shape: tuple[int, ...], device: torch.device,
|
||||
dtype: torch.dtype):
|
||||
shape_numel = prod(shape)
|
||||
if self.buffer is None or self.buffer.numel() < shape_numel:
|
||||
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
|
||||
assert self.buffer.device == device, \
|
||||
f"Buffer device mismatch: {self.buffer.device} != {device}"
|
||||
assert self.buffer.dtype == dtype, \
|
||||
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
|
||||
return self.buffer[:shape_numel].view(*shape)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
fused_out_buffer = SharedResizableBuffer()
|
||||
workspace13_buffer = SharedResizableBuffer()
|
||||
workspace2_buffer = SharedResizableBuffer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace13 = self.workspace13_buffer.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = self.workspace2_buffer.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
fused_out = self.fused_out_buffer.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if (not self.prepare_finalize.supports_async()
|
||||
or self.shared_experts is None):
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
assert not dbo_enabled()
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
receiver = self.prepare_finalize.prepare_async(
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(a1)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
hook()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import pplx_kernels as pplx
|
||||
import torch
|
||||
@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.ReceiverType:
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
@ -214,30 +214,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
return lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
a1q,
|
||||
a1q_scale,
|
||||
topk_ids,
|
||||
bound_m,
|
||||
orig_a_scale_block_shape,
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: Optional[torch.Tensor],
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
topk_ids: torch.Tensor,
|
||||
bound_m: Optional[torch.Tensor],
|
||||
orig_a_scale_block_shape: Optional[int],
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
self.a2a.dispatch(
|
||||
hook = lambda: self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
@ -249,6 +226,21 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
return (hook, lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
orig_a_scale_block_shape,
|
||||
))
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: Optional[torch.Tensor],
|
||||
orig_a_scale_block_shape: Optional[int],
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||
assert expert_x_scale.ndim == 3
|
||||
@ -270,7 +262,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
receiver = self.prepare_async(
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
@ -281,6 +273,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
)
|
||||
hook()
|
||||
return receiver()
|
||||
|
||||
def finalize(
|
||||
|
||||
@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
@ -81,12 +82,6 @@ class CommonAttentionMetadata:
|
||||
encoder_seq_lens: Optional[np.ndarray] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
@ -103,7 +98,7 @@ def slice_query_start_locs(
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UbatchSlice,
|
||||
ubatch_slice: UBatchSlice,
|
||||
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
@ -133,6 +128,11 @@ def _make_metadata_with_slice(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] -
|
||||
query_start_loc_cpu[:-1])).item())
|
||||
|
||||
# This is to account for the case where we are in a dummy
|
||||
# run and query_start_loc_cpu is full of 0s
|
||||
if max_query_len == 0:
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
@ -152,12 +152,12 @@ def _make_metadata_with_slice(
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UbatchSlice],
|
||||
ubatch_slices: list[UBatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UbatchSlice in ubatch_slices.
|
||||
requests for each UBatchSlice in ubatch_slices.
|
||||
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -179,9 +180,11 @@ class EagleProposer:
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
|
||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0)
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
@ -355,8 +358,9 @@ class EagleProposer:
|
||||
hidden_states: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[torch.Tensor]:
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
tree_attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builder
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
|
||||
@ -64,8 +64,13 @@ class CPUModelRunner(GPUModelRunner):
|
||||
if not self.attn_groups[0]:
|
||||
return
|
||||
|
||||
mb = getattr(self.attn_groups[0][0], "metadata_builder", None)
|
||||
if not isinstance(mb, TorchSDPAMetadataBuilderV1):
|
||||
mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
|
||||
if isinstance(mb, list):
|
||||
if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
|
||||
return
|
||||
mb[0].reorder_batch(self.input_batch, scheduler_output)
|
||||
return
|
||||
elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
|
||||
# Encoder-only / rerank models do not benefit from reordering,
|
||||
# so we safely skip here.
|
||||
return
|
||||
|
||||
@ -15,6 +15,7 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention, AttentionType
|
||||
@ -55,11 +56,12 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
|
||||
is_pin_memory_available, round_up, supports_dynamo)
|
||||
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
create_fast_prefill_custom_backend,
|
||||
reorder_batch_to_split_decodes_and_prefills)
|
||||
reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -85,9 +87,12 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
|
||||
from .utils import (AttentionGroup, MultiModalBudget,
|
||||
@ -105,6 +110,11 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
|
||||
# list when ubatching is enabled
|
||||
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
||||
AttnMetadataDict]
|
||||
|
||||
|
||||
# Wrapper for ModelRunnerOutput to support overlapped execution.
|
||||
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
@ -274,6 +284,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Request states.
|
||||
self.requests: dict[str, CachedRequestState] = {}
|
||||
self.comm_stream = torch.cuda.Stream()
|
||||
|
||||
# Input Batch
|
||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||
@ -872,10 +883,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return encoder_seq_lens
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
|
||||
np.ndarray, Optional[CommonAttentionMetadata], int]:
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata], np.ndarray,
|
||||
Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
:return: tuple[
|
||||
attn_metadata: layer-to-attention_metadata mapping,
|
||||
@ -947,6 +959,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
|
||||
|
||||
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
|
||||
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
|
||||
num_tokens_unpadded)
|
||||
ubatch_slices, num_tokens_after_padding = \
|
||||
ubatch_split(max_num_scheduled_tokens,
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
self.vllm_config)
|
||||
|
||||
self.seq_lens.np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
@ -1001,7 +1022,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
attn_metadata: PerLayerAttnMetadata = {}
|
||||
if ubatch_slices is not None:
|
||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||
|
||||
# Used in the below loop.
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||||
@ -1075,7 +1098,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
builder = attn_group.metadata_builder
|
||||
builder = attn_group.get_metadata_builder()
|
||||
if self.cascade_attn_enabled:
|
||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
@ -1093,13 +1116,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
|
||||
)
|
||||
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args)
|
||||
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
if ubatch_slices is not None:
|
||||
common_attn_metadata_list = split_attn_metadata(
|
||||
ubatch_slices, common_attn_metadata)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list):
|
||||
assert common_attn_metadata.max_query_len == 1
|
||||
attn_metadata_i = (attn_group.get_metadata_builder(
|
||||
ubatch_id=ubid).build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata))
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||
else:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
@ -1107,7 +1144,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
return (attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
||||
max_num_scheduled_tokens)
|
||||
max_num_scheduled_tokens, ubatch_slices,
|
||||
num_tokens_after_padding)
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -1508,7 +1546,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
# get raw model out of the cudagraph wrapper.
|
||||
if isinstance(self.model, CUDAGraphWrapper):
|
||||
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
|
||||
return self.model.unwrap()
|
||||
return self.model
|
||||
|
||||
@ -1675,6 +1713,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def get_dp_padding(self,
|
||||
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Determines the total number of tokens that each rank will run.
|
||||
All ranks will be padded out so that they run with the same number
|
||||
of tokens
|
||||
|
||||
Returns: tuple[
|
||||
num_pad_tokens: The number of tokens that will be added to the batch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens for each DP rank including padding.
|
||||
]
|
||||
"""
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
@ -1698,6 +1747,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def get_local_padding(self, num_tokens_unpadded: int) -> int:
|
||||
|
||||
num_tokens_padded = num_tokens_unpadded
|
||||
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_tokens_padded = self.vllm_config.pad_for_cudagraph(
|
||||
num_tokens_unpadded)
|
||||
else:
|
||||
# Eager mode.
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if self.vllm_config.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
|
||||
|
||||
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
|
||||
return num_pad_tokens
|
||||
|
||||
# This is where the second ubatch is adjusted to account for the padding.
|
||||
# Should be called after attention metadata creation. This just pads
|
||||
# the second ubatch slice out to the total number of tokens
|
||||
# (num_tokens + padding)
|
||||
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
|
||||
num_total_tokens: int):
|
||||
padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
|
||||
num_total_tokens)
|
||||
ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice,
|
||||
padded_second_ubatch_slice)
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1758,15 +1840,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
ubatch_slices: Optional[UBatchSlices] = None,
|
||||
num_tokens_after_padding: Optional[torch.Tensor] = None,
|
||||
) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor,
|
||||
Optional[IntermediateTensors], dict[str, Any]]:
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
if ubatch_slices:
|
||||
assert num_tokens_after_padding is not None
|
||||
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
|
||||
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
||||
elif ubatch_slices is None:
|
||||
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
|
||||
num_pad, num_tokens_after_padding = self.get_dp_padding(
|
||||
num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
@ -1821,7 +1910,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return (
|
||||
num_scheduled_tokens,
|
||||
num_input_tokens,
|
||||
num_tokens_across_dp,
|
||||
num_tokens_after_padding,
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
positions,
|
||||
@ -2027,7 +2116,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||
max_query_len) = self._prepare_inputs(scheduler_output)
|
||||
max_query_len, ubatch_slices, num_tokens_after_padding
|
||||
) = self._prepare_inputs(scheduler_output)
|
||||
|
||||
finally:
|
||||
if self.prepare_inputs_event is not None:
|
||||
@ -2042,7 +2132,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
model_kwargs,
|
||||
) = self._preprocess(scheduler_output, intermediate_tensors)
|
||||
) = self._preprocess(scheduler_output, intermediate_tensors,
|
||||
ubatch_slices, num_tokens_after_padding)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
num_input_tokens = num_input_tokens // 2
|
||||
|
||||
uniform_decode = (max_query_len
|
||||
== self.uniform_decode_query_len) and (
|
||||
@ -2062,6 +2156,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
ubatch_slices=ubatch_slices,
|
||||
), record_function_or_nullcontext("Forward"),
|
||||
self.maybe_get_kv_connector_output(scheduler_output) as
|
||||
kv_connector_output):
|
||||
@ -2441,10 +2536,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# CUDAGraphWrapper and CudagraphDispatcher of vllm.
|
||||
|
||||
# wrap the model with full cudagraph wrapper if needed.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \
|
||||
and not self.parallel_config.enable_dbo:
|
||||
self.model = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
elif self.parallel_config.enable_dbo:
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.model = UBatchWrapper(self.model, self.vllm_config,
|
||||
CUDAGraphMode.FULL, self.device)
|
||||
else:
|
||||
self.model = UBatchWrapper(self.model, self.vllm_config,
|
||||
CUDAGraphMode.NONE, self.device)

    def reload_weights(self) -> None:
        assert getattr(self, "model", None) is not None, \
@ -2642,6 +2745,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        force_attention: bool = False,
        uniform_decode: bool = False,
        allow_microbatching: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
        create_mixed_batch: bool = False,
@ -2667,12 +2771,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                (1 token) and prefill (multiple tokens) requests.
            remove_lora: If False, dummy LoRAs are not destroyed after the run
        """
        ubatch_enabled = self.parallel_config.enable_dbo
        num_tokens_across_dp = None
        num_pad = 0
        should_ubatch = False
        if ubatch_enabled:
            should_ubatch = num_tokens >= \
                self.parallel_config.dbo_decode_token_threshold and \
                allow_microbatching

            (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
                num_tokens, num_tokens, should_ubatch, self.vllm_config)

            # Currently the dummy run should only be ubatching during
            # cuda graph capture, meaning all DP ranks should already
            # have the same batch size
            if num_tokens_across_dp is not None:
                assert int(num_tokens_across_dp[0]) == num_tokens // 2

        assert cudagraph_runtime_mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        if not should_ubatch:
            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
            num_tokens += num_pad

        # If cudagraph_mode.decode_mode() == FULL and
@ -2690,6 +2812,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # for GQA/MQA.
        max_query_len = self.uniform_decode_query_len if uniform_decode else \
            num_tokens
        if allow_microbatching:
            assert self.uniform_decode_query_len == 1
            assert uniform_decode is True
            assert max_query_len == 1

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
@ -2728,12 +2854,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        attn_metadata: Optional[dict[str, Any]] = None
        ubatch_slices = None
        # We currently only microbatch if the number of tokens is
        # over a certain threshold.
        if should_ubatch:
            # We only support decode-only cudagraphs
            assert num_reqs == num_tokens
            assert num_tokens % 2 == 0
            ubatch_slices = [
                UBatchSlice(slice(0, num_reqs // 2), slice(0,
                                                           num_tokens // 2)),
                UBatchSlice(slice(num_reqs // 2, num_reqs),
                            slice(num_tokens // 2, num_tokens))
            ]

        attn_metadata: Optional[PerLayerAttnMetadata] = None

        # If force_attention is True, we always capture attention. Otherwise,
        # it only happens for cudagraph_runtime_mode=FULL.
        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
            attn_metadata = {}
            if ubatch_slices is not None:
                attn_metadata = [dict() for _ in range(len(ubatch_slices))]

            if create_mixed_batch:
                # In the mixed batch mode (used for FI warmup), we use
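As a rough illustration of the slicing in the hunk above: for a uniform-decode dummy batch (one token per request), the batch is simply cut in half along both the request and token dimensions. The sketch below uses a stand-in dataclass mirroring UBatchSlice from vllm/v1/worker/ubatch_utils.py (introduced later in this diff) so it runs without vLLM installed; the batch sizes are invented.

from dataclasses import dataclass

@dataclass
class UBatchSlice:  # stand-in mirroring vllm/v1/worker/ubatch_utils.py
    request_slice: slice
    token_slice: slice

def split_uniform_decode(num_reqs: int, num_tokens: int) -> list[UBatchSlice]:
    # Decode-only: one scheduled token per request and an even token count,
    # matching the asserts in _dummy_run above.
    assert num_reqs == num_tokens and num_tokens % 2 == 0
    half_reqs, half_toks = num_reqs // 2, num_tokens // 2
    return [
        UBatchSlice(slice(0, half_reqs), slice(0, half_toks)),
        UBatchSlice(slice(half_reqs, num_reqs), slice(half_toks, num_tokens)),
    ]

if __name__ == "__main__":
    for s in split_uniform_decode(num_reqs=8, num_tokens=8):
        print(s)  # two halves: requests/tokens [0, 4) and [4, 8)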
@ -2766,12 +2908,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    slot_mapping=self.input_batch.
                    block_table[kv_cache_group_id].slot_mapping[:num_tokens],
                    causal=True)

                for attn_group in self.attn_groups[kv_cache_group_id]:
                    attn_metadata_i = attn_group.metadata_builder\
                        .build_for_cudagraph_capture(common_attn_metadata)
                    for layer_name in kv_cache_group_spec.layer_names:
                        attn_metadata[layer_name] = attn_metadata_i
                    if ubatch_slices is not None:
                        common_attn_metadata_list = split_attn_metadata(
                            ubatch_slices, common_attn_metadata)
                        for ubid, common_attn_metadata in enumerate(
                                common_attn_metadata_list):
                            assert common_attn_metadata.max_query_len == 1
                            attn_metadata_i = (attn_group\
                                .get_metadata_builder(ubatch_id=ubid)\
                                .build_for_cudagraph_capture(common_attn_metadata))
                            for layer_name in kv_cache_group_spec.layer_names:
                                assert type(attn_metadata) is list
                                attn_metadata[ubid][
                                    layer_name] = attn_metadata_i
                    else:
                        assert type(attn_metadata) is dict
                        attn_metadata_i = attn_group.get_metadata_builder()\
                            .build_for_cudagraph_capture(common_attn_metadata)
                        for layer_name in kv_cache_group_spec.layer_names:
                            attn_metadata[layer_name] = attn_metadata_i

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens, remove_lora):
@ -2818,13 +2974,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    f"Cudagraph runtime mode mismatch at dummy_run. "
                    f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")

            if ubatch_slices is not None:
                num_tokens = num_tokens // 2
            with self.maybe_randomize_inputs(input_ids), set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=batch_descriptor):
                    batch_descriptor=batch_descriptor,
                    ubatch_slices=ubatch_slices):
                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
@ -3096,6 +3255,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        set_cudagraph_capturing_enabled(True)
        with freeze_gc(), graph_capture(device=self.device):
            cudagraph_mode = self.compilation_config.cudagraph_mode
            assert cudagraph_mode is not None
            if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

@ -3153,6 +3313,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                desc="Capturing CUDA graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    cudagraph_runtime_mode.name))
        enable_dbo = self.parallel_config.enable_dbo
        # DBO only supports running full cudagraphs with uniform
        # decode lengths
        if enable_dbo and uniform_decode:
            for num_tokens in compilation_cases:
                # If the number of tokens is below the microbatching
                # threshold, don't capture a microbatched cudagraph
                if (num_tokens
                        < self.parallel_config.dbo_decode_token_threshold):
                    continue

                # Warmup
                for _ in range(
                        self.compilation_config.cudagraph_num_of_warmups):
                    force_attention = (
                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
                    self._dummy_run(num_tokens,
                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                    force_attention=force_attention,
                                    uniform_decode=True,
                                    allow_microbatching=True,
                                    skip_eplb=True)

                # Graph Capture
                self._dummy_run(num_tokens,
                                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                                uniform_decode=True,
                                allow_microbatching=True,
                                skip_eplb=True)
        # We skip EPLB here since we don't want to record dummy metrics
        for num_tokens in compilation_cases:
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
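To make the capture-size filter above concrete, here is a small standalone sketch of which decode batch sizes end up with a microbatched (DBO) cudagraph. The threshold value and capture sizes are invented for illustration; dbo_decode_token_threshold is the real config knob.

# Hypothetical values for illustration only.
dbo_decode_token_threshold = 32
compilation_cases = [1, 2, 4, 8, 16, 32, 64, 128]

# Mirrors the loop above: sizes below the threshold are skipped, so a
# microbatched cudagraph is only captured for the larger decode batches.
microbatched_capture_sizes = [
    n for n in compilation_cases if n >= dbo_decode_token_threshold
]
print(microbatched_capture_sizes)  # [32, 64, 128]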
@ -3219,14 +3408,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    ) -> list[AttentionGroup]:
        attn_groups: list[AttentionGroup] = []
        for attn_backend, layer_names in attn_backends_map.items():
            attn_metadata_builder_i = attn_backend.get_builder_cls()(
            attn_metadata_builders = []
            attn_metadata_builders.append(attn_backend.get_builder_cls()(
                kv_cache_spec,
                layer_names,
                self.vllm_config,
                self.device,
            )
            ))
            if self.parallel_config.enable_dbo:
                attn_metadata_builders.append(
                    attn_backend.get_builder_cls()(
                        kv_cache_spec,
                        layer_names,
                        self.vllm_config,
                        self.device,
                    ))
            attn_group = AttentionGroup(attn_backend,
                                        attn_metadata_builder_i,
                                        attn_metadata_builders,
                                        layer_names)
            attn_groups.append(attn_group)
        return attn_groups
@ -3246,11 +3444,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        min_cg_builder_name = None

        for attn_group in self._attn_group_iterator():
            builder = attn_group.metadata_builder
            builder = attn_group.get_metadata_builder()
            if builder.cudagraph_support.value < min_cg_support.value:
                min_cg_support = builder.cudagraph_support
                min_cg_builder_name = builder.__class__.__name__

        # Flexibly resolve the cudagraph mode
        cudagraph_mode = self.compilation_config.cudagraph_mode
        # check cudagraph for mixed batch is supported
@ -3316,7 +3513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            is compatible (e.g., decode threshold is the same)
        """
        for group in self._attn_group_iterator():
            attn_metadata_builder_i = group.metadata_builder
            attn_metadata_builder_i = group.get_metadata_builder()

            # check that if any backends reorder batches; that the reordering
            # is compatible (e.g., decode threshold is the same)

vllm/v1/worker/gpu_ubatch_wrapper.py (new file, 303 lines)
@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import threading
from typing import Any, Callable, Optional

import torch

from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import (create_forward_context, get_forward_context,
                                  override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

logger = init_logger(__name__)


@dataclasses.dataclass
class UbatchMetadata:
    context: UBatchContext
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    num_tokens: int


@dataclasses.dataclass
class CUDAGraphMetaData:
    cudagraph: torch.cuda.CUDAGraph
    ubatch_metadata: UbatchMetadata
    outputs: Optional[Any] = None


class UBatchWrapper:

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, device: torch.cuda.device):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.cuda.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode)
            self.graph_pool = current_platform.get_global_graph_pool()

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable "
                             f"of cudagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure
        that each of the ubatch threads initializes its CUDA context before
        we start the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its cuda context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured cudagraph along with its
        metadata and returns.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            ubatch_context = ubatch_metadata.context
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        num_tokens = ubatch_metadata[0].num_tokens + \
            ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_capture_ubatch_thread,
                                          args=(
                                              results,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = \
                CUDAGraphMetaData(
                    cudagraph=torch.cuda.CUDAGraph(),
                    ubatch_metadata=ubatch_metadata,
                )
            with torch.cuda.graph(cudagraph_metadata.cudagraph,
                                  stream=compute_stream,
                                  pool=self.graph_pool):
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                sorted_results = [value for position, value in sorted(results)]
                result = torch.cat(sorted_results, dim=0)
                cudagraph_metadata.outputs = result
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_ubatch_thread,
                                          args=(
                                              results,
                                              model,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        sorted_results = [value for position, value in sorted(results)]
        result = torch.cat(sorted_results, dim=0)
        return result

    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                              positions, inputs_embeds, intermediate_tensors,
                              compute_stream, dp_metadata, batch_descriptor,
                              cudagraph_runtime_mode) -> list[UbatchMetadata]:

        # Create one forward context per ubatch
        forward_contexts = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            forward_contexts.append(
                create_forward_context(
                    attn_metadata[i] if attn_metadata is not None else None,
                    self.vllm_config,
                    dp_metadata=dp_metadata,
                    batch_descriptor=batch_descriptor,
                    cudagraph_runtime_mode=cudagraph_runtime_mode))

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier)

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
                sliced_intermediate_tensors = \
                self._slice_model_inputs(
                    ubatch_slice.token_slice, input_ids, positions,
                    inputs_embeds, intermediate_tensors)
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=ubatch_slice.token_slice.stop -
                    ubatch_slice.token_slice.start))

        return ubatch_metadata

    def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
                            inputs_embeds, intermediate_tensors):
        sliced_input_ids = input_ids[tokens_slice]
        # M-RoPE adds an additional leading dimension to the positions tensor,
        # so slice along the last dimension in that case.
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        sliced_inputs_embeds = inputs_embeds[
            tokens_slice] if inputs_embeds else None
        sliced_intermediate_tensors = intermediate_tensors[
            tokens_slice] if intermediate_tensors else None

        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
                sliced_intermediate_tensors)

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                          CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.cudagraph_wrapper is not None
                return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        num_tokens = (ubatch_slices[0].token_slice.stop -
                      ubatch_slices[0].token_slice.start) * 2
        input_ids = kwargs['input_ids']
        positions = kwargs['positions']
        intermediate_tensors = kwargs['intermediate_tensors']
        inputs_embeds = kwargs['inputs_embeds']
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        if num_tokens not in self.cudagraphs \
                and cudagraph_runtime_mode is CUDAGraphMode.FULL:
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE)

            return self._capture_ubatches(ubatch_metadata, self.model)
        elif num_tokens in self.cudagraphs:
            cudagraph_metadata = self.cudagraphs[num_tokens]
            cudagraph_metadata.cudagraph.replay()
            return cudagraph_metadata.outputs
        else:
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE)
            return self._run_ubatches(ubatch_metadata, self.model)
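The dispatch in __call__ above boils down to a cudagraph cache keyed by the combined token count of both ubatches: capture once on the first FULL-mode visit, replay on later hits, and otherwise fall back to eager microbatched execution. Below is a stripped-down sketch of just that decision, with no CUDA and with hypothetical helper names (TinyGraphCache, capture, run_eager) standing in for the real _capture_ubatches / _run_ubatches paths.

from typing import Any, Callable

class TinyGraphCache:
    """Toy model of UBatchWrapper's capture/replay bookkeeping (no CUDA)."""

    def __init__(self, capture: Callable[[int], Any],
                 run_eager: Callable[[int], Any]):
        self._graphs: dict[int, Any] = {}  # num_tokens -> captured outputs
        self._capture = capture            # stands in for _capture_ubatches
        self._run_eager = run_eager        # stands in for _run_ubatches

    def __call__(self, num_tokens: int, full_cudagraph_mode: bool) -> Any:
        if num_tokens not in self._graphs and full_cudagraph_mode:
            self._graphs[num_tokens] = self._capture(num_tokens)  # capture once
            return self._graphs[num_tokens]
        if num_tokens in self._graphs:
            return self._graphs[num_tokens]                       # "replay"
        return self._run_eager(num_tokens)                        # eager ubatches

cache = TinyGraphCache(capture=lambda n: f"graph({n})",
                       run_eager=lambda n: f"eager({n})")
print(cache(64, full_cudagraph_mode=True))   # capture
print(cache(64, full_cudagraph_mode=True))   # replay
print(cache(48, full_cudagraph_mode=False))  # eager microbatched run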
vllm/v1/worker/ubatch_splitting.py (new file, 155 lines)
@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
                                         is_second_ubatch_empty)

logger = init_logger(__name__)


def should_ubatch_with_num_tokens(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
    dp_size = vllm_config.parallel_config.data_parallel_size
    dp_rank = vllm_config.parallel_config.data_parallel_rank
    return DPMetadata.should_ubatch_across_dp(should_ubatch,
                                              orig_num_tokens_per_ubatch,
                                              padded_num_tokens_per_ubatch,
                                              dp_size, dp_rank)


def get_dp_padding_ubatch(
        num_tokens_unpadded: int, num_tokens_padded: int,
        should_attempt_ubatching: bool,
        vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do. If this function decides
    not to run with microbatching, it will "abort", meaning that no padding
    information will be returned to the caller; it will return (False, None).

    2. Determines the total number of tokens that each rank will run.
    All ranks will be padded out so that they run with the same number
    of tokens.

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if should_ubatch is False
    ]
    """
    assert num_tokens_padded >= num_tokens_unpadded
    dp_size = vllm_config.parallel_config.data_parallel_size
    if dp_size == 1:
        # Early exit.
        return False, None

    # If this DP rank doesn't want to attempt microbatching
    if not should_attempt_ubatching:
        (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
            False, 0, 0, vllm_config)
        assert should_ubatch is False
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    # Round up to the next multiple of two for even divisibility
    num_tokens_padded = round_up(num_tokens_padded, 2)
    num_tokens_per_ubatch = num_tokens_padded // 2
    should_ubatch = True

    # Sanity check that the existing padding isn't giving us an empty second
    # ubatch. Abort if so
    if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
        logger.debug(
            "Empty second µbatch detected: unpadded tokens: %s, padded "
            "tokens: %s", num_tokens_unpadded, num_tokens_padded)
        should_ubatch = False

    # Note that we compute the number of padded tokens per ubatch
    (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
        should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
        vllm_config)
    if not should_ubatch:
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    assert num_tokens_across_dp is not None

    max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
    num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                            dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return should_ubatch, num_tokens_after_padding


def ubatch_split(
    max_num_scheduled_tokens: int,
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
    """
    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    Returns: tuple[
        ubatch_slices: if this is set then all DP ranks have agreed to
        microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if ubatch_slices is None
    ]
    """
    parallel_config = vllm_config.parallel_config
    # Don't bother with the should_ubatch handshaking unless microbatching
    # is enabled
    if not parallel_config.enable_dbo:
        return (None, None)

    # Check preconditions for microbatching
    should_attempt_ubatching = \
        parallel_config.enable_dbo and \
        num_tokens_unpadded >= \
        parallel_config.dbo_decode_token_threshold \
        and max_num_scheduled_tokens == 1

    # Don't microbatch unless every other DP worker is also microbatching
    num_tokens_after_padding = None
    (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
        num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
        vllm_config)
    if not should_ubatch:
        return (None, None)

    # This doesn't actually pad the ubatch slices. It just initializes the
    # split point to the padded value so that padding can be applied
    # to the second ubatch in pad_out_ubatch_slice after attention
    # metadata creation
    assert num_tokens_after_padding is not None
    total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
    padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                       num_tokens_unpadded)

    # Note there's an assumption here that there's 1 token per request
    ubatch_slices = [
        UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
        UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
    ]

    return (ubatch_slices, num_tokens_after_padding)
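A quick worked example of the padding arithmetic in get_dp_padding_ubatch, with invented per-rank token counts and the distributed handshake replaced by a local max (should_ubatch_across_dp is not reproduced here).

# Invented per-rank unpadded token counts for a 2-way DP setup.
tokens_unpadded = {0: 96, 1: 70}

# Per rank: round up to an even count and halve to get tokens per ubatch,
# mirroring round_up(num_tokens_padded, 2) // 2 above.
per_ubatch = {r: ((t + 1) // 2 * 2) // 2 for r, t in tokens_unpadded.items()}
# -> {0: 48, 1: 35}

# The DP handshake then pads every rank's ubatch up to the global max, so
# both microbatches on both ranks run with the same token count.
tokens_per_ubatch_after_padding = max(per_ubatch.values())
print(tokens_per_ubatch_after_padding)  # 48 -> each rank runs 2 ubatches of 48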
vllm/v1/worker/ubatch_utils.py (new file, 19 lines)
@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

from typing_extensions import TypeAlias


@dataclass
class UBatchSlice:
    request_slice: slice
    token_slice: slice


UBatchSlices: TypeAlias = list[UBatchSlice]


def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
                           padded_num_tokens_per_ubatch: int) -> bool:
    return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
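For intuition, the helper above flags the case where padding alone would let the first microbatch hold every real token, leaving the second microbatch with nothing but padding. A couple of invented examples, with the predicate duplicated locally so the snippet runs standalone:

# Duplicated locally so the snippet runs standalone; same predicate as above.
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
    return padded_num_tokens >= 2 * orig_num_tokens

# Called the way get_dp_padding_ubatch calls it: with the *total* unpadded and
# padded token counts. Each ubatch gets padded_num_tokens // 2 slots.
print(is_second_ubatch_empty(10, 16))  # False: ubatch size 8, second ubatch keeps 2 real tokens
print(is_second_ubatch_empty(10, 20))  # True: ubatch size 10 holds all real tokens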
vllm/v1/worker/ubatching.py (new file, 211 lines)
@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional

import torch

from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream

_THREAD_ID_TO_CONTEXT: dict = {}
_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None]


class UBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.cuda.Stream,
                 compute_stream: torch.cuda.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.cuda.Event,
                 gpu_compute_done_event: torch.cuda.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        self.recv_hook = None

    def __enter__(self):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()

        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we start on the compute stream
        assert current_stream() == self.compute_stream
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        self.current_stream = self.compute_stream
        torch.cuda.set_stream(self.current_stream)
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context
        torch.cuda.set_stream(self.current_stream)

    def update_stream(self, stream):
        self.current_stream = stream
        torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        if self.current_stream != current_stream():
            self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()


def dbo_enabled() -> bool:
    return len(_THREAD_ID_TO_CONTEXT) > 0


def dbo_current_ubatch_id() -> int:
    if len(_THREAD_ID_TO_CONTEXT) == 0:
        return 0
    return _THREAD_ID_TO_CONTEXT[threading.get_ident()]


def _register_ubatch_function(func):

    def wrapper(*args, **kwargs):
        if len(_THREAD_ID_TO_CONTEXT) > 0:
            ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
            ctx = _CURRENT_CONTEXTS[ctx_idx]
            func(ctx, *args, **kwargs)

    return wrapper


dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_maybe_run_recv_hook = _register_ubatch_function(
    UBatchContext.maybe_run_recv_hook)
dbo_switch_to_comm_sync = _register_ubatch_function(
    UBatchContext.switch_to_comm_sync)


def dbo_register_recv_hook(recv_hook):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
        next_ctx.recv_hook = recv_hook


def make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.cuda.Stream,
    comm_stream: torch.cuda.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the context managers for micro-batching synchronization.
    """
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]

    assert len(forward_contexts) == 2

    ctxs = []
    for i in range(num_micro_batches):
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
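The CPU side of the ping-pong above can be reproduced without CUDA: each ubatch thread owns a wait event, signals the other thread's wait event, and then blocks on its own, so exactly one thread is runnable at a time. A minimal sketch using only the standard threading module (the worker body and step counts are invented; there is no stream or event recording here):

import threading

def make_events(n: int = 2) -> list[threading.Event]:
    return [threading.Event() for _ in range(n)]

def worker(uid: int, wait_ev: threading.Event, signal_ev: threading.Event,
           steps: int) -> None:
    for step in range(steps):
        wait_ev.wait()           # sleep until the other ubatch yields to us
        wait_ev.clear()
        print(f"ubatch {uid} runs step {step}")  # stand-in for compute/comm work
        signal_ev.set()          # wake the other ubatch, mirroring _cpu_yield()

if __name__ == "__main__":
    events = make_events()
    threads = [
        threading.Thread(target=worker,
                         args=(i, events[i], events[(i + 1) % 2], 3))
        for i in range(2)
    ]
    for t in threads:
        t.start()
    events[0].set()              # the "main thread" kicks off ubatch 0
    for t in threads:
        t.join()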
@ -130,9 +130,17 @@ class MultiModalBudget:
@dataclass
class AttentionGroup:
    backend: type[AttentionBackend]
    metadata_builder: AttentionMetadataBuilder
    metadata_builders: list[AttentionMetadataBuilder]
    layer_names: list[str]

    def get_metadata_builder(self,
                             ubatch_id: Optional[int] = None
                             ) -> AttentionMetadataBuilder:
        if ubatch_id is None:
            return self.metadata_builders[0]
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]

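Call sites elsewhere in this diff (for example the per-ubatch attention-metadata build in _dummy_run) select a builder roughly as sketched below. The _Group class is a schematic stand-in for AttentionGroup, not real vLLM setup code, and the builder strings are placeholders.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Group:  # schematic stand-in for AttentionGroup's builder selection
    metadata_builders: list[str]

    def get_metadata_builder(self, ubatch_id: Optional[int] = None) -> str:
        if ubatch_id is None:
            return self.metadata_builders[0]
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]

# With enable_dbo, two builders are created per attention group, one per ubatch.
group = _Group(metadata_builders=["builder-ubatch0", "builder-ubatch1"])
print(group.get_metadata_builder())             # default (non-ubatched) path
print(group.get_metadata_builder(ubatch_id=1))  # per-ubatch path during DBO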
def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,