mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 22:25:32 +08:00)
[Core/DBO][2/N] Dual-Batch Overlap add DeepEP High Throughput support and Prefill support (#24845)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent a903669e10
commit cc1dc7ed6d
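The central interface change in this commit is that prepare_async() and finalize_async() may now return either a bare receiver callback or a (hook, receiver) pair, where the hook is a lightweight wait on the recv; callers unpack the return value and, under DBO, register the hook with the ubatch context instead of calling it directly. The following is a minimal, self-contained sketch of that calling convention. The toy_* names are stand-ins and not vLLM APIs; the real code additionally calls dbo_register_recv_hook() and dbo_yield(), as shown in the diff below.

from typing import Callable, Union

Receiver = Callable[[], str]
Hook = Callable[[], None]

def toy_prepare_async(return_hook: bool) -> Union[tuple[Hook, Receiver], Receiver]:
    # Stand-in for prepare_async(): the dispatch is assumed to already be in
    # flight; the receiver waits for it and returns the results.
    receiver: Receiver = lambda: "dispatched tokens ready"
    if return_hook:
        hook: Hook = lambda: None  # lightweight "recv finished" check
        return hook, receiver
    return receiver

def consume(prepare_ret, dbo_is_enabled: bool) -> str:
    # Unpack either a (hook, receiver) pair or a bare receiver
    # (mirrors the isinstance(..., tuple) check added in modular_kernel.py).
    hook, receiver = (prepare_ret if isinstance(prepare_ret, tuple)
                      else (None, prepare_ret))
    if hook is not None:
        if dbo_is_enabled:
            # Real code: dbo_register_recv_hook(hook); dbo_yield()
            pass
        else:
            hook()
    return receiver()

print(consume(toy_prepare_async(return_hook=True), dbo_is_enabled=False))
print(consume(toy_prepare_async(return_hook=False), dbo_is_enabled=False))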
@@ -5,11 +5,12 @@ import pytest
import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
                                              _make_metadata_with_slice,
                                              slice_query_start_locs,
                                              split_attn_metadata)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices


@pytest.fixture
@@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
    assert results[1].num_reqs == mid_point
    assert results[1].num_actual_tokens == mid_point
    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


@pytest.mark.parametrize(
    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
    [
        # Split in the middle of request 1
        ([32, 40], [8, 8], 12, 2, 1),
        # Split inside the first request
        ([32, 40], [8, 8], 4, 1, 2),
    ],
)
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
                                       expected_first_reqs,
                                       expected_second_reqs):
    """Test splitting a prefill across ubatches"""
    import numpy as np

    device = torch.device("cpu")
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
    common = create_common_attn_metadata(batch_spec,
                                         block_size=16,
                                         device=device)

    num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
    qsl_np = common.query_start_loc_cpu.numpy()
    num_tokens = common.num_actual_tokens

    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
    assert len(ubatch_slices) == 2

    first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
    second_meta = _make_metadata_with_slice(ubatch_slices[1], common)

    # Token counts match the split
    assert first_meta.num_actual_tokens == split_point
    assert second_meta.num_actual_tokens == num_tokens - split_point

    # Number of requests per ubatch
    assert first_meta.num_reqs == expected_first_reqs
    assert second_meta.num_reqs == expected_second_reqs

    # Identify which request is split and how many tokens are in the first chunk
    split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
    tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
    orig_q_lens = (common.query_start_loc_cpu[1:] -
                   common.query_start_loc_cpu[:-1])

    # Check query length continuity: first-chunk + second-chunk == original qlen
    # First ubatch last request query length
    qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
                          first_meta.query_start_loc_cpu[-2])
    # Second ubatch first request query length
    qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
                            second_meta.query_start_loc_cpu[0])
    assert qlen_first_last == tokens_in_first_chunk
    assert qlen_first_last + qlen_second_first == int(
        orig_q_lens[split_req_idx])

    # Check seq_lens adjustments
    # Context lengths per original request
    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]

    # First ubatch: last request's seq_len should be
    # context + tokens_in_first_chunk
    expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
    assert int(first_meta.seq_lens[-1]) == expected_seqlen

    # For full preceding requests in first ubatch, seq_lens should match
    # originals
    for i in range(first_meta.num_reqs - 1):
        assert int(first_meta.seq_lens[i]) == seq_lens[i]

    # Second ubatch: first request (continuation) seq_len should be full
    # original
    assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
    # Any following full requests in second ubatch should match originals
    for j in range(1, second_meta.num_reqs):
        # Map to original request index
        orig_idx = split_req_idx + j
        assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]

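As a worked illustration of the first parametrized case above (query_lens [8, 8], split_point 12), here is a small standalone sketch of the request/token boundary math performed by create_ubatch_slices in vllm/v1/worker/ubatch_splitting.py (shown later in this diff). It is illustrative only and not part of the commit.

import numpy as np

query_lens = np.array([8, 8], dtype=np.int32)
split_point = 12

cu_num_tokens = np.zeros(len(query_lens) + 1, dtype=np.int32)
np.cumsum(query_lens, dtype=np.int32, out=cu_num_tokens[1:])  # [0, 8, 16]

# First ubatch: all requests overlapping tokens [0, split_point)
first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))        # 2
# Second ubatch: starts at the request that contains split_point
second_req_start = int(np.searchsorted(cu_num_tokens, split_point, side="right") - 1)  # 1

print(slice(0, first_req_stop), slice(0, split_point))                    # reqs 0-1, tokens 0-11
print(slice(second_req_start, len(query_lens)), slice(split_point, 16))   # req 1,   tokens 12-15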
@@ -532,9 +532,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].metadata_builders = [
    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
        attn_metadata_builder
    ]

    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
@@ -659,9 +658,8 @@ def test_propose_tree(spec_token_tree):
    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].metadata_builders = [
    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
        attn_metadata_builder
    ]

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0,

@@ -638,11 +638,13 @@ class VllmConfig:

        if self.parallel_config.enable_dbo:
            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
            assert a2a_backend == "deepep_low_latency", \
                "Microbatching currently only supports the deepep_low_latency "\
                f"all2all backend. {a2a_backend} is not supported. To fix set "\
                "the VLLM_ALL2ALL_BACKEND environment variable to "\
                "deepep_low_latency and install the DeepEP kerenls."
            assert a2a_backend in \
                ["deepep_low_latency", "deepep_high_throughput"], \
                "Microbatching currently only supports the deepep_low_latency and "\
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
                "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
                "variable to deepep_low_latency or deepep_high_throughput and "\
                "install the DeepEP kernels."

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

@@ -139,12 +139,18 @@ class ParallelConfig:
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    enable_dbo: bool = False
    """Enable microbatching for the model executor."""
    """Enable dual batch overlap for the model executor."""

    dbo_decode_token_threshold: int = 32
    """The threshold for microbatching. If the number of tokens in the
    request is greater than this threshold, microbatching will be used.
    Otherwise, the request will be processed in a single batch."""
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""

    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
import deep_ep
|
||||
|
||||
# Right now the buffers are sized for only what the kernels were
|
||||
# created with. So we can only reduce the number of SMS used
|
||||
# but not increase it.
|
||||
if num_sms > self.num_sms:
|
||||
num_sms = self.num_sms
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
import deep_ep
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
return handle
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return 0
|
||||
@ -60,6 +60,12 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -330,6 +330,8 @@ class EngineArgs:
|
||||
enable_dbo: bool = ParallelConfig.enable_dbo
|
||||
dbo_decode_token_threshold: int = \
|
||||
ParallelConfig.dbo_decode_token_threshold
|
||||
dbo_prefill_token_threshold: int = \
|
||||
ParallelConfig.dbo_prefill_token_threshold
|
||||
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
|
||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||
expert_placement_strategy: ExpertPlacementStrategy = \
|
||||
@ -698,6 +700,9 @@ class EngineArgs:
|
||||
parallel_group.add_argument(
|
||||
"--dbo-decode-token-threshold",
|
||||
**parallel_kwargs["dbo_decode_token_threshold"])
|
||||
parallel_group.add_argument(
|
||||
"--dbo-prefill-token-threshold",
|
||||
**parallel_kwargs["dbo_prefill_token_threshold"])
|
||||
parallel_group.add_argument("--enable-eplb",
|
||||
**parallel_kwargs["enable_eplb"])
|
||||
parallel_group.add_argument("--eplb-config",
|
||||
@ -1316,6 +1321,7 @@ class EngineArgs:
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_dbo=self.enable_dbo,
|
||||
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
|
||||
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
|
||||
enable_eplb=self.enable_eplb,
|
||||
eplb_config=self.eplb_config,
|
||||
expert_placement_strategy=self.expert_placement_strategy,
|
||||
|
||||
vllm/envs.py
@@ -189,6 +189,8 @@ if TYPE_CHECKING:
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
    VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
    VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
    VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
    VLLM_DBO_COMM_SMS: int = 20
    GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
    VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None

@@ -1392,6 +1394,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
                      "VLLM_OBJECT_STORAGE_SHM_BUFFER"),

    # The size in MB of the buffers (NVL and RDMA) used by DeepEP
    "VLLM_DEEPEP_BUFFER_SIZE_MB":
    lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),

    # The number of SMs to allocate for communication kernels when running DBO
    # the rest of the SMs on the device will be allocated to compute
    "VLLM_DBO_COMM_SMS":
    lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),

    # Valid values are container,code_interpreter,web_search_preview
    # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
    "GPT_OSS_SYSTEM_TOOL_MCP_LABELS":

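For context, the two knobs added above are plain environment variables read at startup. The sketch below only illustrates the names and defaults introduced in this diff; the override values chosen here are arbitrary examples, and the readers are simplified stand-ins for vllm.envs.

import os

# Hypothetical overrides set before launching vLLM (values are examples only).
os.environ.setdefault("VLLM_DEEPEP_BUFFER_SIZE_MB", "2048")  # NVL/RDMA buffer size in MB
os.environ.setdefault("VLLM_DBO_COMM_SMS", "16")             # SMs reserved for comms under DBO

# envs.py-style lazy readers (defaults match the diff: 1024 MB and 20 SMs).
deepep_buffer_mb = int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024"))
dbo_comm_sms = int(os.getenv("VLLM_DBO_COMM_SMS", "20"))
print(deepep_buffer_mb, dbo_comm_sms)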
@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.utils import round_up
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
|
||||
dbo_switch_to_compute, dbo_switch_to_compute_sync,
|
||||
dbo_yield_and_switch_from_comm_to_compute,
|
||||
dbo_yield_and_switch_from_compute_to_comm)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.async_prepare = True
|
||||
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handle = None
|
||||
# requires. Under DBO microbatching we must track one handle per
|
||||
# micro-batch to avoid races between threads.
|
||||
self.handles = [None, None]
|
||||
|
||||
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
||||
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
||||
@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
has_scales = token_scales is not None
|
||||
|
||||
# We yield before launching the dispatch kernel since the dispatch
|
||||
# kernel will block the CPU so we want to queue up all the compute
|
||||
# for the other ubatch before the dispatch kernel starts.
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
|
||||
(num_tokens_per_rank, num_tokens_per_rdma_rank,
|
||||
dispatch_expert_num_tokens, is_token_in_rank,
|
||||
event) = self.buffer.get_dispatch_layout(
|
||||
@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
(
|
||||
token_data, expert_topk_ids, expert_topk_weights,
|
||||
expert_num_tokens_per_expert_list, self.handle, event
|
||||
expert_num_tokens_per_expert_list, handle, event
|
||||
) = self.buffer.dispatch(
|
||||
x=token_data,
|
||||
handle=None,
|
||||
@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
async_finish=self.async_prepare,
|
||||
async_finish=self.async_prepare and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
# record the handle for this ubatch
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
self.handles[a2a_idx] = handle
|
||||
|
||||
dbo_switch_to_compute_sync()
|
||||
|
||||
return lambda: self._receiver(
|
||||
event,
|
||||
has_scales,
|
||||
@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
if self.async_prepare:
|
||||
if event.event is not None:
|
||||
event.current_stream_wait()
|
||||
|
||||
if has_scales:
|
||||
@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
) -> mk.ReceiverType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1q_scale = None
|
||||
a1_post_scale = quant_config.a1_scale
|
||||
|
||||
return (lambda *args: None,
|
||||
self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config))
|
||||
return self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
|
||||
num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts,
|
||||
expert_map, apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def _finalize(
|
||||
@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_async: bool,
|
||||
) -> Optional[Callable]:
|
||||
|
||||
assert self.handle is not None
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
handle = self.handles[a2a_idx]
|
||||
assert handle is not None
|
||||
|
||||
# fused_expert_output can have 0 tokens - This happens when none of the
|
||||
# tokens from the all2all reach this EP rank.
|
||||
@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
combined_x, _, event = self.buffer.combine(
|
||||
x=fused_expert_output,
|
||||
handle=self.handle,
|
||||
handle=handle,
|
||||
topk_weights=None,
|
||||
config=self._get_combine_config(),
|
||||
previous_event=None,
|
||||
async_finish=do_async,
|
||||
async_finish=do_async and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
dbo_switch_to_compute()
|
||||
|
||||
if do_async:
|
||||
|
||||
def _receiver():
|
||||
event.current_stream_wait()
|
||||
if event.event is not None:
|
||||
event.current_stream_wait()
|
||||
dbo_switch_to_comm()
|
||||
# Respect inplace outputs.
|
||||
output.copy_(combined_x, non_blocking=True)
|
||||
|
||||
return lambda: _receiver()
|
||||
# TODO(lucas): refactor the modular kernel so this will be
|
||||
# handled there
|
||||
dbo_yield_and_switch_from_comm_to_compute()
|
||||
|
||||
return _receiver
|
||||
else:
|
||||
# TODO(lucas): support this case with the refactored modular kernel
|
||||
assert not dbo_enabled()
|
||||
# Respect inplace outputs.
|
||||
output.copy_(combined_x, non_blocking=True)
|
||||
return None
|
||||
|
||||
@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
do_async: bool,
|
||||
) -> Optional[Callable]:
|
||||
) -> tuple[Callable, Callable]:
|
||||
assert isinstance(
|
||||
weight_and_reduce_impl, TopKWeightAndReduceDelegate
|
||||
), ("Weight application and reduction happens in the combine kernel.")
|
||||
@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return_recv_hook=do_recv_hook,
|
||||
out=output)
|
||||
|
||||
return recv_hook
|
||||
return recv_hook, lambda: None
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
recv_hook = self._finalize(
|
||||
) -> tuple[Callable, Callable]:
|
||||
return self._finalize(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
weight_and_reduce_impl,
|
||||
do_async=True,
|
||||
)
|
||||
assert recv_hook is not None
|
||||
return recv_hook
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
|
||||
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook, dbo_yield)
|
||||
|
||||
#
|
||||
@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, ReceiverType]:
|
||||
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `prepare`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`prepare`, if a hook is returned this is more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
e.g.
|
||||
|
||||
receiver = obj.prepare_async(...)
|
||||
ret = obj.prepare_async(...)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
|
||||
if hook is not None:
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
|
||||
|
||||
is equivalent to:
|
||||
@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
) -> Union[tuple[Callable, Callable], Callable]:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output but do not wait for results from other workers.
|
||||
@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `finalize`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`finalize`, if a hook is returned this is more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
receiver = obj.finalize_async(output, ...)
|
||||
ret = obj.finalize_async(output, ...)
|
||||
... output not valid yet ...
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
receiver()
|
||||
... output valid here ...
|
||||
|
||||
@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
fused_out_buffer = SharedResizableBuffer()
|
||||
workspace13_buffer = SharedResizableBuffer()
|
||||
workspace2_buffer = SharedResizableBuffer()
|
||||
|
||||
class SharedBuffers:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.fused_out = SharedResizableBuffer()
|
||||
self.workspace13 = SharedResizableBuffer()
|
||||
self.workspace2 = SharedResizableBuffer()
|
||||
|
||||
# Persistent buffers that are shared across `FusedMoEModularKernel`
|
||||
# instances (layers), to save memory and allocattions.
|
||||
#
|
||||
# We have two sets of buffers to support dual batch overlap (DBO) where each
|
||||
# microbatch (ubatch) should use its own set of buffers to avoid
|
||||
# cross-ubatch contimination.
|
||||
# NOTE that memory is lazily allocated for these buffers, meaning that if
|
||||
# DBO isn't being used, the second SharedBuffers will be empty.
|
||||
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = self.workspace13_buffer.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = self.workspace2_buffer.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace13 = buffers.workspace13.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = buffers.workspace2.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = self.fused_out_buffer.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
fused_out = buffers.fused_out.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = prepare_ret \
|
||||
if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
else:
|
||||
recv_hook = self.prepare_finalize.finalize_async(
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
assert recv_hook is not None
|
||||
dbo_register_recv_hook(recv_hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
recv_hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = finalize_ret \
|
||||
if isinstance(finalize_ret, tuple) else (None, finalize_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
receiver()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
|
||||
@ -107,19 +107,57 @@ def _make_metadata_with_slice(
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
assert not ubatch_slice.is_empty(), (
|
||||
f"Ubatch slice {ubatch_slice} is empty")
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
start_locs = attn_metadata.query_start_loc_cpu
|
||||
first_req = request_slice.start
|
||||
first_tok = token_slice.start
|
||||
last_req = request_slice.stop - 1
|
||||
last_tok = token_slice.stop - 1
|
||||
|
||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
|
||||
"Token slice start outside of first request"
|
||||
assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \
|
||||
"Token slice end outside of last request"
|
||||
|
||||
# If the "middle" request has tokens in both ubatches, we have to split it.
|
||||
# If ubatch_slice is the first ubatch then we will be splitting the last
|
||||
# request. If it's the second microbatch, then we will be splitting the
|
||||
# first request
|
||||
splits_first_request = first_tok > start_locs[first_req]
|
||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||
|
||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||
query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
|
||||
request_slice)
|
||||
|
||||
assert len(query_start_loc) >= 2, (
|
||||
f"query_start_loc must have at least 2 elements, "
|
||||
f"got {len(query_start_loc)}")
|
||||
query_start_loc_cpu = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc_cpu, request_slice)
|
||||
|
||||
if splits_first_request:
|
||||
tokens_skipped = first_tok - start_locs[first_req]
|
||||
query_start_loc[1:] -= tokens_skipped
|
||||
query_start_loc_cpu[1:] -= tokens_skipped
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
|
||||
if splits_last_request:
|
||||
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
|
||||
query_start_loc[-1] -= tokens_skipped
|
||||
query_start_loc_cpu[-1] -= tokens_skipped
|
||||
|
||||
# Make sure we don't modify the seq_lens tensors
|
||||
# (not cudagraph compatible)
|
||||
seq_lens = seq_lens.clone()
|
||||
seq_lens_cpu = seq_lens_cpu.clone()
|
||||
seq_lens[-1] -= tokens_skipped
|
||||
seq_lens_cpu[-1] -= tokens_skipped
|
||||
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
|
||||
request_slice]
|
||||
@ -167,6 +205,7 @@ def split_attn_metadata(
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(
|
||||
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -696,7 +735,6 @@ def split_decodes_and_prefills(
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
assert torch.all(query_lens[first_prefill:] > decode_threshold)
|
||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
|
||||
@ -30,7 +30,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -192,9 +191,8 @@ class EagleProposer:
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
||||
|
||||
@ -330,7 +328,7 @@ class EagleProposer:
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata = attn_metadata_builder\
|
||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1)
|
||||
@ -538,9 +536,8 @@ class EagleProposer:
|
||||
hidden_states: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[torch.Tensor]:
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
tree_attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
|
||||
@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
|
||||
from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
|
||||
ubatch_split)
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
|
||||
@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
|
||||
num_tokens_unpadded)
|
||||
ubatch_slices, num_tokens_after_padding = \
|
||||
ubatch_split(max_num_scheduled_tokens,
|
||||
ubatch_split(num_scheduled_tokens,
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
self.vllm_config)
|
||||
@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
ubatch_slices, common_attn_metadata)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list):
|
||||
assert common_attn_metadata.max_query_len == 1
|
||||
attn_metadata_i = (attn_group.get_metadata_builder(
|
||||
ubatch_id=ubid).build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
@ -2182,9 +2182,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
) = self._preprocess(scheduler_output, intermediate_tensors,
|
||||
ubatch_slices, num_tokens_after_padding)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
num_input_tokens = num_input_tokens // 2
|
||||
|
||||
uniform_decode = (max_query_len
|
||||
== self.uniform_decode_query_len) and (
|
||||
num_scheduled_tokens
|
||||
@ -2194,6 +2191,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode, batch_descriptor = \
|
||||
self.cudagraph_dispatcher.dispatch(batch_descriptor)
|
||||
|
||||
# This is currently to get around the assert in the DPMetadata
|
||||
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
||||
if ubatch_slices is not None:
|
||||
num_input_tokens = ubatch_slices[0].num_tokens
|
||||
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with (set_forward_context(
|
||||
@ -2821,7 +2823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
allow_microbatching: bool = False,
|
||||
allow_microbatching: bool = True,
|
||||
skip_eplb: bool = False,
|
||||
is_profile: bool = False,
|
||||
create_mixed_batch: bool = False,
|
||||
@ -2847,32 +2849,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
"""
|
||||
ubatch_enabled = self.parallel_config.enable_dbo
|
||||
num_tokens_across_dp = None
|
||||
num_pad = 0
|
||||
should_ubatch = False
|
||||
if ubatch_enabled:
|
||||
should_ubatch = num_tokens >= \
|
||||
self.parallel_config.dbo_decode_token_threshold and \
|
||||
allow_microbatching
|
||||
|
||||
(should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
|
||||
num_tokens, num_tokens, should_ubatch, self.vllm_config)
|
||||
|
||||
# Currently the dummy run should only be ubatching during
|
||||
# cuda graph capture, meaning all DP ranks should already
|
||||
# have the same batch size
|
||||
if num_tokens_across_dp is not None:
|
||||
assert int(num_tokens_across_dp[0]) == num_tokens // 2
|
||||
|
||||
assert cudagraph_runtime_mode in {
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
}
|
||||
|
||||
if not should_ubatch:
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
num_tokens += num_pad
|
||||
|
||||
# If cudagraph_mode.decode_mode() == FULL and
|
||||
# cudagraph_mode.separate_routine(). This means that we are using
|
||||
# different graphs and/or modes for mixed prefill-decode batches vs.
|
||||
@ -2888,10 +2868,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# for GQA/MQA.
|
||||
max_query_len = self.uniform_decode_query_len if uniform_decode else \
|
||||
num_tokens
|
||||
if allow_microbatching:
|
||||
assert self.uniform_decode_query_len == 1
|
||||
assert uniform_decode is True
|
||||
assert max_query_len == 1
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
@ -2930,20 +2906,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
|
||||
|
||||
ubatch_slices = None
|
||||
num_tokens_after_padding = None
|
||||
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
if should_ubatch:
|
||||
# We only support decode-only cudagraphs
|
||||
assert num_reqs == num_tokens
|
||||
assert num_tokens % 2 == 0
|
||||
ubatch_slices = [
|
||||
UBatchSlice(slice(0, num_reqs // 2), slice(0,
|
||||
num_tokens // 2)),
|
||||
UBatchSlice(slice(num_reqs // 2, num_reqs),
|
||||
slice(num_tokens // 2, num_tokens))
|
||||
]
|
||||
if self.parallel_config.enable_dbo and allow_microbatching:
|
||||
ubatch_slices, num_tokens_after_padding = ubatch_split(
|
||||
num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
# If we failed to microbatch, currently need to resynchronize
|
||||
# TODO(lucas,sage): we should be able to avoid this second sync by
|
||||
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
|
||||
# a single `coordinate_batch_across_dp` function.
|
||||
if num_tokens_after_padding is None:
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
num_tokens_after_padding = num_tokens + num_pad
|
||||
else:
|
||||
num_tokens_across_dp = num_tokens_after_padding
|
||||
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
|
||||
|
||||
attn_metadata: Optional[PerLayerAttnMetadata] = None
|
||||
|
||||
@ -2966,6 +2953,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.seq_lens.np[num_reqs:] = 0
|
||||
self.seq_lens.copy_to_gpu()
|
||||
|
||||
cum_num_tokens, _ = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
@ -3060,7 +3052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
with self.maybe_randomize_inputs(input_ids), set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
@ -3395,56 +3387,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
desc="Capturing CUDA graphs ({}, {})".format(
|
||||
"decode" if uniform_decode else "mixed prefill-decode",
|
||||
cudagraph_runtime_mode.name))
|
||||
enable_dbo = self.parallel_config.enable_dbo
|
||||
# DBO Only supports running Full cudagraphs with uniform
|
||||
# decode lengths
|
||||
if enable_dbo and uniform_decode:
|
||||
for num_tokens in compilation_cases:
|
||||
# If the number of tokens is greater than the microbatching
|
||||
# threshold, don't generate a microbatched cudagraph
|
||||
if (num_tokens
|
||||
< self.parallel_config.dbo_decode_token_threshold):
|
||||
continue
|
||||
|
||||
# Warmup
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens in compilation_cases:
|
||||
# We currently only capture ubatched graphs when its a FULL
|
||||
# cudagraph and for uniform decode batches.
|
||||
capture_ubatched_graph = self.parallel_config.enable_dbo \
|
||||
and cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
||||
and uniform_decode \
|
||||
and check_ubatch_thresholds(
|
||||
config=self.vllm_config.parallel_config,
|
||||
num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
# Currently we capture both microbatched and non-microbatched
|
||||
# graphs when capture_ubatched_graph is True, this is because
|
||||
# occasionally we will be forced out of microbatching due to other
|
||||
# DP ranks not microbatching (usually caused by an empty second
|
||||
# microbatch; once we resolve this, we can remove the
|
||||
# non-microbatched graph capture).
|
||||
allow_microbatching_options = [True, False] if \
|
||||
capture_ubatched_graph else [False]
|
||||
for allow_microbatching in allow_microbatching_options:
|
||||
for _ in range(
|
||||
self.compilation_config.cudagraph_num_of_warmups):
|
||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
||||
# But be careful, warm up with `NONE`is orthogonal to
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
force_attention = (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.FULL)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=True,
|
||||
allow_microbatching=True,
|
||||
skip_eplb=True)
|
||||
|
||||
# Graph Capture
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
uniform_decode=True,
|
||||
allow_microbatching=True,
|
||||
skip_eplb=True)
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens in compilation_cases:
|
||||
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
|
||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
||||
# But be careful, warm up with `NONE`is orthogonal to
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
force_attention = (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.FULL)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
@ -3500,24 +3487,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_groups: list[AttentionGroup] = []
|
||||
for (attn_backend,
|
||||
kv_cache_spec), layer_names in attn_backends_map.items():
|
||||
attn_metadata_builders = []
|
||||
attn_metadata_builders.append(attn_backend.get_builder_cls()(
|
||||
kv_cache_spec,
|
||||
attn_group = AttentionGroup.create_with_metadata_builders(
|
||||
attn_backend,
|
||||
layer_names,
|
||||
kv_cache_spec,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
))
|
||||
if self.parallel_config.enable_dbo:
|
||||
attn_metadata_builders.append(
|
||||
attn_backend.get_builder_cls()(
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
))
|
||||
attn_group = AttentionGroup(attn_backend,
|
||||
attn_metadata_builders,
|
||||
layer_names, kv_cache_spec)
|
||||
num_metadata_builders=1
|
||||
if not self.parallel_config.enable_dbo else 2,
|
||||
)
|
||||
|
||||
attn_groups.append(attn_group)
|
||||
return attn_groups
|
||||
|
||||
|
||||
@ -1,25 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||
override_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class UbatchMetadata:
|
||||
context: UBatchContext
|
||||
input_ids: torch.Tensor
|
||||
@ -29,13 +32,55 @@ class UbatchMetadata:
|
||||
num_tokens: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class CUDAGraphMetaData:
|
||||
cudagraph: torch.cuda.CUDAGraph
|
||||
ubatch_metadata: UbatchMetadata
|
||||
outputs: Optional[Any] = None
|
||||
|
||||
|
||||
class SMControlContextManager:
|
||||
|
||||
def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
|
||||
set_compute_sms: Callable[[int], None]):
|
||||
"""
|
||||
Context manager for controlling SM (Streaming Multiprocessor)
|
||||
allocation. Upon entering the context, it sets the number of SMs
|
||||
allocated for communication and computation to comm_sms and
|
||||
total_sms - comm_sms respectively. Upon exiting, it restores the
|
||||
allocation to use all available SMs (i.e. total_sms).
|
||||
|
||||
Args:
|
||||
comm_sms (int): The number of SMs to allocate for communication.
|
||||
(The remainder will be used for computation.)
|
||||
set_comm_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for communication.
|
||||
set_compute_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for computation.
|
||||
"""
|
||||
|
||||
assert current_platform.is_cuda(), \
|
||||
"SM control is currently only supported on CUDA"
|
||||
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
total_sms = props.multi_processor_count
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
self.compute_sms = total_sms - comm_sms
|
||||
self.comm_sms = comm_sms
|
||||
self.set_comm_sms = set_comm_sms
|
||||
self.set_compute_sms = set_compute_sms
|
||||
|
||||
def __enter__(self):
|
||||
self.set_comm_sms(self.comm_sms)
|
||||
self.set_compute_sms(self.compute_sms)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.set_comm_sms(self.total_sms)
|
||||
self.set_compute_sms(self.total_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
@ -56,6 +101,35 @@ class UBatchWrapper:
|
||||
runnable, vllm_config, runtime_mode=runtime_mode)
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
|
||||
@staticmethod
|
||||
def _create_sm_control_context(vllm_config: VllmConfig):
|
||||
comm_sms = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
all2all_manager = get_ep_group(
|
||||
).device_communicator.all2all_manager
|
||||
|
||||
if all2all_manager.max_sms_used() is not None:
|
||||
comm_sms = min(comm_sms, all2all_manager.max_sms_used())
|
||||
|
||||
if comm_sms > 0:
|
||||
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
|
||||
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(comm_sms=comm_sms,
|
||||
set_comm_sms=set_comm_sms,
|
||||
set_compute_sms=set_compute_sms)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
@ -282,8 +356,8 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
elif num_tokens in self.cudagraphs:
|
||||
cudagraph_metadata = self.cudagraphs[num_tokens]
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
@ -300,4 +374,5 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.forward_context import DPMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import round_up
|
||||
@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens(
|
||||
dp_size, dp_rank)
|
||||
|
||||
|
||||
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
|
||||
uniform_decode: bool) -> bool:
|
||||
if not config.enable_dbo:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
def get_dp_padding_ubatch(
|
||||
num_tokens_unpadded: int, num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
@@ -95,9 +106,37 @@ def get_dp_padding_ubatch(
                                            dtype=torch.int32)
    return should_ubatch, num_tokens_after_padding


def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
        -> UBatchSlices:
    # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
    # in cu_num_tokens directly (i.e. query_start_loc)
    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

    first_ubatch_token_slice = slice(0, split_point)
    second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])

    # Determine request slices using exclusive stop semantics
    # First ubatch includes requests whose tokens overlap [0, split_point)
    first_ubatch_req_stop = int(
        np.searchsorted(cu_num_tokens, split_point, side="left"))
    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)

    # Second ubatch starts at the request that contains the split_point
    # or the request starting exactly at split_point (if on boundary)
    second_ubatch_req_start = int(
        np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
    second_ubatch_req_slice = slice(second_ubatch_req_start,
                                    len(cu_num_tokens) - 1)

    return [
        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
    ]


def ubatch_split(
    max_num_scheduled_tokens: int,
    num_scheduled_tokens_per_request: np.ndarray,
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    vllm_config: VllmConfig,
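
To see what `create_ubatch_slices` produces, consider three requests with 4 scheduled tokens each and a split point of 6 (numbers chosen only for illustration). The cumulative token counts are [0, 4, 8, 12], so the first ubatch covers requests 0 and 1 with tokens [0, 6), while the second ubatch starts inside request 1 (the request that straddles the split) and covers tokens [6, 12).

```python
import numpy as np

num_scheduled_tokens = np.array([4, 4, 4], dtype=np.int32)
split_point = 6

cu_num_tokens = np.zeros(4, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
# cu_num_tokens == [0, 4, 8, 12]

first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))
second_req_start = int(
    np.searchsorted(cu_num_tokens, split_point, side="right") - 1)

assert first_req_stop == 2    # requests 0 and 1 overlap [0, 6)
assert second_req_start == 1  # request 1 straddles the split point
```
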
@@ -122,17 +161,20 @@ def ubatch_split(
        return (None, None)

    # Check preconditions for microbatching
    should_attempt_ubatching = \
        parallel_config.enable_dbo and \
        num_tokens_unpadded >= \
        parallel_config.dbo_decode_token_threshold \
        and max_num_scheduled_tokens == 1
    should_attempt_ubatching = check_ubatch_thresholds(
        parallel_config,
        num_tokens_unpadded,
        vllm_config,
    )

    # Don't microbatch unless every other DP worker is also microbatching
    num_tokens_after_padding = None
    (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
        num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
        vllm_config)
    should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
        num_tokens_unpadded,
        num_tokens_padded,
        should_attempt_ubatching,
        vllm_config,
    )

    if not should_ubatch:
        return (None, None)
@@ -141,15 +183,9 @@ def ubatch_split(
    # to the second ubatch in pad_out_ubatch_slice after attention
    # metadata creation
    assert num_tokens_after_padding is not None
    total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
    padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                       num_tokens_unpadded)
    token_split_point = int(num_tokens_after_padding[0].item())

    # Note there's an assumption here that there's 1 token per request
    ubatch_slices = [
        UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
        UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
    ]
    ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
                                         token_split_point)

    return (ubatch_slices, num_tokens_after_padding)
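
The rewritten tail of `ubatch_split` no longer builds the slices inline (and no longer assumes one token per request): the per-ubatch padded token count agreed on across DP ranks becomes the split point, and `create_ubatch_slices` does the rest. A small numeric sketch of that hand-off, with made-up values standing in for `num_tokens_after_padding`:

```python
import torch

# Suppose DP coordination settled on 128 padded tokens per ubatch...
num_tokens_after_padding = torch.tensor([128, 128], dtype=torch.int32)
num_tokens_unpadded = 200

token_split_point = int(num_tokens_after_padding[0].item())

# ...then the first ubatch owns tokens [0, 128) and the second owns
# [128, 200); per the comment above, the second ubatch is padded out
# later (pad_out_ubatch_slice), after attention metadata creation.
assert token_split_point == 128
assert num_tokens_unpadded - token_split_point == 72
```
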
@@ -10,6 +10,14 @@ class UBatchSlice:
    request_slice: slice
    token_slice: slice

    def is_empty(self) -> bool:
        return self.request_slice.start == self.request_slice.stop \
            or self.token_slice.start == self.token_slice.stop

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


UBatchSlices: TypeAlias = list[UBatchSlice]
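
The new `num_tokens` property exposes a slice's token count directly to callers. A tiny stand-alone illustration of the dataclass semantics (re-declared here so the snippet runs on its own):

```python
from dataclasses import dataclass


@dataclass
class UBatchSlice:  # re-declared for the example
    request_slice: slice
    token_slice: slice

    def is_empty(self) -> bool:
        return self.request_slice.start == self.request_slice.stop \
            or self.token_slice.start == self.token_slice.stop

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


first = UBatchSlice(slice(0, 2), slice(0, 6))
second = UBatchSlice(slice(1, 3), slice(6, 12))
assert (first.num_tokens, second.num_tokens) == (6, 6)
assert not first.is_empty() and not second.is_empty()
assert UBatchSlice(slice(0, 0), slice(0, 0)).is_empty()
```
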
@@ -51,8 +51,8 @@ class UBatchContext:
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we start on the compute stream
        assert current_stream() == self.compute_stream
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
@@ -62,17 +62,15 @@ class UBatchContext:
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        self.current_stream = self.compute_stream
        torch.cuda.set_stream(self.current_stream)
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context
        torch.cuda.set_stream(self.current_stream)

    def update_stream(self, stream):
        self.current_stream = stream
        torch.cuda.set_stream(self.current_stream)
        if current_stream() != self.current_stream:
            torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)
@@ -99,9 +97,20 @@ class UBatchContext:
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
@@ -112,8 +121,7 @@ class UBatchContext:
    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        if self.current_stream != current_stream():
            self.update_stream(self.current_stream)
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
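
The `yield_*` methods above implement the DBO ping-pong: each ubatch's CPU thread signals its peer, then blocks until it is woken again, switching the active CUDA stream as it resumes. Below is a CPU-only analogue of that signal/wait hand-off, using two `threading.Event`s in place of `cpu_signal_event`/`cpu_wait_event`; it only illustrates the yield protocol, with no CUDA streams involved.

```python
import threading

order = []


def ubatch(name: str, wait_event: threading.Event,
           signal_event: threading.Event, steps: int = 3) -> None:
    for i in range(steps):
        wait_event.wait()           # block until the other ubatch yields to us
        wait_event.clear()
        order.append(f"{name}{i}")  # "do some work" for this microbatch
        signal_event.set()          # yield_: hand control to the other ubatch


ev_a, ev_b = threading.Event(), threading.Event()
t_a = threading.Thread(target=ubatch, args=("A", ev_a, ev_b))
t_b = threading.Thread(target=ubatch, args=("B", ev_b, ev_a))
t_a.start()
t_b.start()
ev_a.set()  # kick off ubatch A first
t_a.join()
t_b.join()

assert order == ["A0", "B0", "A1", "B1", "A2", "B2"]
```
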
@@ -153,15 +161,20 @@ def _register_ubatch_function(func):
    return wrapper


dbo_maybe_run_recv_hook = _register_ubatch_function(
    UBatchContext.maybe_run_recv_hook)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_maybe_run_recv_hook = _register_ubatch_function(
    UBatchContext.maybe_run_recv_hook)
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
dbo_switch_to_compute = _register_ubatch_function(
    UBatchContext.switch_to_compute)
dbo_switch_to_comm_sync = _register_ubatch_function(
    UBatchContext.switch_to_comm_sync)
dbo_switch_to_compute_sync = _register_ubatch_function(
    UBatchContext.switch_to_compute_sync)


def dbo_register_recv_hook(recv_hook):
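
The newly registered `dbo_switch_to_*` entry points follow the same pattern as the existing `dbo_yield` and `dbo_maybe_run_recv_hook`: an `UBatchContext` method is wrapped into a module-level hook that model code can call unconditionally. The sketch below shows only that registration idea; the plain `_CURRENT_CTX` variable and the toy context class are assumptions for illustration, not how vLLM tracks the active ubatch context.

```python
from typing import Callable, Optional


class _Ctx:  # toy stand-in for UBatchContext
    def __init__(self, name: str) -> None:
        self.name = name

    def switch_to_comm(self) -> None:
        print(f"{self.name}: now on comm stream")


_CURRENT_CTX: Optional[_Ctx] = None


def _register_ubatch_function(method: Callable[[_Ctx], None]) -> Callable[[], None]:
    # Wrap a context method into a hook that is a no-op when no ubatch
    # context is active, so model code can call it unconditionally.
    def wrapper() -> None:
        if _CURRENT_CTX is not None:
            method(_CURRENT_CTX)
    return wrapper


dbo_switch_to_comm = _register_ubatch_function(_Ctx.switch_to_comm)

dbo_switch_to_comm()            # no active context: does nothing
_CURRENT_CTX = _Ctx("ubatch0")
dbo_switch_to_comm()            # prints "ubatch0: now on comm stream"
```
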
@@ -130,15 +130,32 @@ class MultiModalBudget:
@dataclass
class AttentionGroup:
    backend: type[AttentionBackend]
    # When ubatching is enabled we will have a metadata builder for each
    # ubatch, so that if they use internal persistent buffers for cudagraphs
    # they won't have to worry about conflicting with the other ubatches.
    metadata_builders: list[AttentionMetadataBuilder]
    layer_names: list[str]
    kv_cache_spec: KVCacheSpec

    @staticmethod
    def create_with_metadata_builders(
        backend: type[AttentionBackend],
        layer_names: list[str],
        kv_cache_spec: KVCacheSpec,
        vllm_config: VllmConfig,
        device: torch.device,
        num_metadata_builders: int = 1,
    ) -> 'AttentionGroup':
        metadata_builders = [
            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
                                      device)
            for _ in range(num_metadata_builders)
        ]
        return AttentionGroup(backend, metadata_builders, layer_names,
                              kv_cache_spec)

    def get_metadata_builder(self,
                             ubatch_id: Optional[int] = None
                             ) -> AttentionMetadataBuilder:
        if ubatch_id is None:
            return self.metadata_builders[0]
                             ubatch_id: int = 0) -> AttentionMetadataBuilder:
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]
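
The `AttentionGroup` change gives each ubatch its own metadata builder, so builders that keep persistent internal buffers (for example, for cudagraph capture) never alias across ubatches. A stripped-down sketch of the same factory-plus-indexing pattern, with a trivial builder class in place of the real `AttentionMetadataBuilder` (class and method names below are illustrative, not vLLM's API):

```python
from dataclasses import dataclass


class _ToyBuilder:  # stand-in for a metadata builder with internal state
    def __init__(self) -> None:
        self.persistent_buffer = []


@dataclass
class _ToyAttentionGroup:
    metadata_builders: list

    @staticmethod
    def create(num_metadata_builders: int = 1) -> "_ToyAttentionGroup":
        # One builder per ubatch, so per-ubatch persistent buffers never alias.
        return _ToyAttentionGroup(metadata_builders=[
            _ToyBuilder() for _ in range(num_metadata_builders)
        ])

    def get_metadata_builder(self, ubatch_id: int = 0) -> _ToyBuilder:
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]


group = _ToyAttentionGroup.create(num_metadata_builders=2)
assert group.get_metadata_builder(0) is not group.get_metadata_builder(1)
group.get_metadata_builder(1).persistent_buffer.append("ubatch1-only")
assert group.get_metadata_builder(0).persistent_buffer == []
```
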