diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index c74dbb3ebb17..7d7a46910be8 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -5,11 +5,12 @@ import pytest import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import create_common_attn_metadata +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, split_attn_metadata) +from vllm.v1.worker.ubatch_utils import create_ubatch_slices @pytest.fixture @@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert results[1].num_reqs == mid_point assert results[1].num_actual_tokens == mid_point assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) + + +@pytest.mark.parametrize( + "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", + [ + # Split in the middle of request 1 + ([32, 40], [8, 8], 12, 2, 1), + # Split inside the first request + ([32, 40], [8, 8], 4, 1, 2), + ], +) +def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point, + expected_first_reqs, + expected_second_reqs): + """Test splitting a prefill across ubatches""" + import numpy as np + + device = torch.device("cpu") + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) + common = create_common_attn_metadata(batch_spec, + block_size=16, + device=device) + + num_scheduled_tokens = np.array(query_lens, dtype=np.int32) + qsl_np = common.query_start_loc_cpu.numpy() + num_tokens = common.num_actual_tokens + + ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point) + assert len(ubatch_slices) == 2 + + first_meta = _make_metadata_with_slice(ubatch_slices[0], common) + second_meta = _make_metadata_with_slice(ubatch_slices[1], common) + + # Token counts match the split + assert first_meta.num_actual_tokens == split_point + assert second_meta.num_actual_tokens == num_tokens - split_point + + # Number of requests per ubatch + assert first_meta.num_reqs == expected_first_reqs + assert second_meta.num_reqs == expected_second_reqs + + # Identify which request is split and how many tokens are in the first chunk + split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1) + tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx]) + orig_q_lens = (common.query_start_loc_cpu[1:] - + common.query_start_loc_cpu[:-1]) + + # Check query length continuity: first-chunk + second-chunk == original qlen + # First ubatch last request query length + qlen_first_last = int(first_meta.query_start_loc_cpu[-1] - + first_meta.query_start_loc_cpu[-2]) + # Second ubatch first request query length + qlen_second_first = int(second_meta.query_start_loc_cpu[1] - + second_meta.query_start_loc_cpu[0]) + assert qlen_first_last == tokens_in_first_chunk + assert qlen_first_last + qlen_second_first == int( + orig_q_lens[split_req_idx]) + + # Check seq_lens adjustments + # Context lengths per original request + context_lens = [s - q for s, q in zip(seq_lens, query_lens)] + + # First ubatch: last request's seq_len should be + # context + tokens_in_first_chunk + expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk + assert int(first_meta.seq_lens[-1]) == expected_seqlen + + # For full preceding requests in first ubatch, seq_lens should match + # originals + 
for i in range(first_meta.num_reqs - 1): + assert int(first_meta.seq_lens[i]) == seq_lens[i] + + # Second ubatch: first request (continuation) seq_len should be full + # original + assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx] + # Any following full requests in second ubatch should match originals + for j in range(1, second_meta.num_reqs): + # Map to original request index + orig_idx = split_req_idx + j + assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx] diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index e7f6b68fc3f7..23bfabfcf89b 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -532,9 +532,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ + proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ attn_metadata_builder - ] result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, @@ -659,9 +658,8 @@ def test_propose_tree(spec_token_tree): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ + proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ attn_metadata_builder - ] # Setup inputs for the proposer. target_token_ids = torch.randint(0, diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 92fc68f8927c..a2562a10b45a 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -638,11 +638,13 @@ class VllmConfig: if self.parallel_config.enable_dbo: a2a_backend = envs.VLLM_ALL2ALL_BACKEND - assert a2a_backend == "deepep_low_latency", \ - "Microbatching currently only supports the deepep_low_latency "\ - f"all2all backend. {a2a_backend} is not supported. To fix set "\ - "the VLLM_ALL2ALL_BACKEND environment variable to "\ - "deepep_low_latency and install the DeepEP kerenls." + assert a2a_backend in \ + ["deepep_low_latency", "deepep_high_throughput"], \ + "Microbatching currently only supports the deepep_low_latency and "\ + f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ + "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ + "variable to deepep_low_latency or deepep_high_throughput and "\ + "install the DeepEP kernels." if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index a84d88243016..f80eb1adc7fd 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -139,12 +139,18 @@ class ParallelConfig: """Disable the custom all-reduce kernel and fall back to NCCL.""" enable_dbo: bool = False - """Enable microbatching for the model executor.""" + """Enable dual batch overlap for the model executor.""" dbo_decode_token_threshold: int = 32 - """The threshold for microbatching. If the number of tokens in the - request is greater than this threshold, microbatching will be used. - Otherwise, the request will be processed in a single batch.""" + """The threshold for dual batch overlap for batches only containing decodes. + If the number of tokens in the request is greater than this threshold, + microbatching will be used. 
Otherwise, the request will be processed in a + single batch.""" + dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune + """The threshold for dual batch overlap for batches that contain one or more + prefills. If the number of tokens in the request is greater than this + threshold, microbatching will be used. Otherwise, the request will be + processed in a single batch.""" ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 149df73d8667..ae18429f6251 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import Any, Optional import torch import torch.distributed as dist +import vllm.envs as envs from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger @@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. - num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_rdma_bytes = None num_qps_per_rank = None if self.internode: - num_rdma_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = self.num_sms // 2 else: num_rdma_bytes = 0 @@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. - handle.set_num_sms(self.num_sms) return handle + def set_num_sms(self, num_sms: int): + import deep_ep + + # Right now the buffers are sized for only what the kernels were + # created with. So we can only reduce the number of SMS used + # but not increase it. + if num_sms > self.num_sms: + num_sms = self.num_sms + deep_ep.Buffer.set_num_sms(num_sms) + class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): """ @@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): import deep_ep # Defaults for internode and intranode are taken from DeepEP tests. 
- num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = num_local_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, @@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) return handle + + # DeepEP LL uses RDMA so no SMs are used for communication + def max_sms_used(self) -> Optional[int]: + return 0 \ No newline at end of file diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 01f59b44a0e6..586441c91783 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -60,6 +60,12 @@ class All2AllManagerBase: # and reuse it for the same config. raise NotImplementedError + def set_num_sms(self, num_sms: int): + pass + + def max_sms_used(self) -> Optional[int]: + return None # None means it could use the whole GPU + def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): raise NotImplementedError diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8c7a1b413cdb..556a490ffa10 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -330,6 +330,8 @@ class EngineArgs: enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = \ ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = \ + ParallelConfig.dbo_prefill_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb expert_placement_strategy: ExpertPlacementStrategy = \ @@ -698,6 +700,9 @@ class EngineArgs: parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"]) + parallel_group.add_argument( + "--dbo-prefill-token-threshold", + **parallel_kwargs["dbo_prefill_token_threshold"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", @@ -1316,6 +1321,7 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, + dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, expert_placement_strategy=self.expert_placement_strategy, diff --git a/vllm/envs.py b/vllm/envs.py index ee5efff8bcd9..f6eafe892ef2 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -189,6 +189,8 @@ if TYPE_CHECKING: VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" + VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 + VLLM_DBO_COMM_SMS: int = 20 GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None @@ -1392,6 +1394,15 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER"), + # The size in MB of the buffers (NVL and RDMA) used by DeepEP + "VLLM_DEEPEP_BUFFER_SIZE_MB": + lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")), + + # The number of SMs to allocate for communication kernels when running DBO + # the rest of the SMs on the device 
will be allocated to compute + "VLLM_DBO_COMM_SMS": + lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), + # Valid values are container,code_interpreter,web_search_preview # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index a250a6218715..9e9a9afc18a0 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) from vllm.utils import round_up +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm, + dbo_switch_to_compute, dbo_switch_to_compute_sync, + dbo_yield_and_switch_from_comm_to_compute, + dbo_yield_and_switch_from_compute_to_comm) class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.async_prepare = True # The dispatch function returns a handle that the combine function - # requires. We store the handle here so it is available to the - # combine function. - self.handle = None + # requires. Under DBO microbatching we must track one handle per + # micro-batch to avoid races between threads. + self.handles = [None, None] # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] @@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): has_scales = token_scales is not None + # We yield before launching the dispatch kernel since the dispatch + # kernel will block the CPU so we want to queue up all the compute + # for the other ubatch before the dispatch kernel starts. 
+ dbo_yield_and_switch_from_compute_to_comm() + (num_tokens_per_rank, num_tokens_per_rdma_rank, dispatch_expert_num_tokens, is_token_in_rank, event) = self.buffer.get_dispatch_layout( @@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ( token_data, expert_topk_ids, expert_topk_weights, - expert_num_tokens_per_expert_list, self.handle, event + expert_num_tokens_per_expert_list, handle, event ) = self.buffer.dispatch( x=token_data, handle=None, @@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=self.async_prepare, + async_finish=self.async_prepare and not dbo_enabled(), allocate_on_comm_stream=False) + # record the handle for this ubatch + a2a_idx = dbo_current_ubatch_id() + self.handles[a2a_idx] = handle + + dbo_switch_to_compute_sync() + return lambda: self._receiver( event, has_scales, @@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1_scale: Optional[torch.Tensor], quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - if self.async_prepare: + if event.event is not None: event.current_stream_wait() if has_scales: @@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[Callable, mk.ReceiverType]: + ) -> mk.ReceiverType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1q_scale = None a1_post_scale = quant_config.a1_scale - return (lambda *args: None, - self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config)) + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) def prepare( self, @@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids, - num_experts, expert_map, - apply_router_weight_on_input, - quant_config) + receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, + expert_map, apply_router_weight_on_input, + quant_config) return receiver() def _finalize( @@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_async: bool, ) -> Optional[Callable]: - assert self.handle is not None + a2a_idx = dbo_current_ubatch_id() + handle = self.handles[a2a_idx] + assert handle is not None # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. 
@@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, ) - + dbo_yield_and_switch_from_compute_to_comm() combined_x, _, event = self.buffer.combine( x=fused_expert_output, - handle=self.handle, + handle=handle, topk_weights=None, config=self._get_combine_config(), previous_event=None, - async_finish=do_async, + async_finish=do_async and not dbo_enabled(), allocate_on_comm_stream=False) + dbo_switch_to_compute() + if do_async: def _receiver(): - event.current_stream_wait() + if event.event is not None: + event.current_stream_wait() + dbo_switch_to_comm() # Respect inplace outputs. output.copy_(combined_x, non_blocking=True) - return lambda: _receiver() + # TODO(lucas): refactor the modular kernel so this will be + # handled there + dbo_yield_and_switch_from_comm_to_compute() + + return _receiver else: + # TODO(lucas): support this case with the refactored modular kernel + assert not dbo_enabled() # Respect inplace outputs. output.copy_(combined_x, non_blocking=True) return None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 101fc8798c42..a9554291db69 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, - ) -> Optional[Callable]: + ) -> tuple[Callable, Callable]: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") @@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return_recv_hook=do_recv_hook, out=output) - return recv_hook + return recv_hook, lambda: None def finalize_async( self, @@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> Callable: - recv_hook = self._finalize( + ) -> tuple[Callable, Callable]: + return self._finalize( output, fused_expert_output, topk_weights, @@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): weight_and_reduce_impl, do_async=True, ) - assert recv_hook is not None - return recv_hook def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5fce24018e64..4ba14196682a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv -from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook, +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, + dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield) # @@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[Callable, ReceiverType]: + ) -> Union[tuple[Callable, ReceiverType], 
ReceiverType]:
         """
         Perform any quantization (and/or) dispatching needed for this kernel
         but do not wait for results from other workers.
@@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.
 
-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `prepare`, e.g.
+        Returns either a receiver callback or a (hook, receiver) pair that,
+        when invoked, waits for results from other workers and has the same
+        return signature as `prepare`. If a hook is returned, it is a more
+        lightweight check that the recv is complete without doing extra work
+        (used by DBO; will be refactored in the very near future).
+
+        e.g.
 
-            receiver = obj.prepare_async(...)
+            ret = obj.prepare_async(...)
+
+            if isinstance(ret, tuple):
+                hook, receiver = ret
+                hook()
+
+            # then wait for the prepared inputs:
             a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
 
         is equivalent to:
@@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: TopKWeightAndReduce,
-    ) -> Callable:
+    ) -> Union[tuple[Callable, Callable], Callable]:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output but do not wait for results from other workers.
@@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
         - weight_and_reduce_impl: An optional TopKWeightAndReduce
           implementation.
 
-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `finalize`, e.g.
+        Returns either a receiver callback or a (hook, receiver) pair that,
+        when invoked, waits for results from other workers and has the same
+        return signature as `finalize`. If a hook is returned, it is a more
+        lightweight check that the recv is complete without doing extra work
+        (used by DBO; will be refactored in the very near future).
 
-            receiver = obj.finalize_async(output, ...)
+            ret = obj.finalize_async(output, ...)
             ... output not valid yet ...
+            if isinstance(ret, tuple):
+                hook, receiver = ret
+                hook()
             receiver()
             ... output valid here ...
@@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the
     component objects.
     """
-    fused_out_buffer = SharedResizableBuffer()
-    workspace13_buffer = SharedResizableBuffer()
-    workspace2_buffer = SharedResizableBuffer()
+
+    class SharedBuffers:
+
+        def __init__(self) -> None:
+            self.fused_out = SharedResizableBuffer()
+            self.workspace13 = SharedResizableBuffer()
+            self.workspace2 = SharedResizableBuffer()
+
+    # Persistent buffers that are shared across `FusedMoEModularKernel`
+    # instances (layers), to save memory and allocations.
+    #
+    # We have two sets of buffers to support dual batch overlap (DBO), where
+    # each microbatch (ubatch) uses its own set of buffers to avoid
+    # cross-ubatch contamination.
+    # NOTE that memory is lazily allocated for these buffers, so if DBO isn't
+    # being used, the second SharedBuffers stays empty.
+ shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] def __init__( self, @@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module): a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) + # select per-ubatch buffers to avoid cross-ubatch reuse under DBO + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] + # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = self.workspace13_buffer.get(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = self.workspace2_buffer.get(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + workspace13 = buffers.workspace13.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = buffers.workspace2.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module): (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = self.fused_out_buffer.get(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] + fused_out = buffers.fused_out.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module): if not self.prepare_finalize.supports_async(): # We shouldn't be running an a2a kernel that doesn't # support async prepare/finalize + # TODO(lucas): enable in follow-up assert not dbo_enabled() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, @@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module): else: # Overlap shared expert compute with all2all dispatch. dbo_maybe_run_recv_hook() - hook, receiver = self.prepare_finalize.prepare_async( + prepare_ret = self.prepare_finalize.prepare_async( a1, topk_weights, topk_ids, @@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module): self.fused_experts.quant_config, ) - # If DBO is being used, register the hook with the ubatch context - # and call it in dbo_maybe_run_recv_hook instead of passing it to - # the receiver. - dbo_register_recv_hook(hook) - dbo_yield() - if not dbo_enabled(): - hook() + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = prepare_ret \ + if isinstance(prepare_ret, tuple) else (None, prepare_ret) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. 
+ dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = receiver() @@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module): if self.shared_experts is not None: shared_output = self.shared_experts(a1) else: - recv_hook = self.prepare_finalize.finalize_async( + finalize_ret = self.prepare_finalize.finalize_async( output, fused_out, topk_weights, @@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module): if self.shared_experts is not None: shared_output = self.shared_experts(a1) - assert recv_hook is not None - dbo_register_recv_hook(recv_hook) - dbo_yield() - if not dbo_enabled(): - recv_hook() + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = finalize_ret \ + if isinstance(finalize_ret, tuple) else (None, finalize_ret) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + receiver() if self.shared_experts is None: return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 6ef489f5a7a2..f837439f953e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -107,19 +107,57 @@ def _make_metadata_with_slice( the requests included in ubatch_slice """ + assert not ubatch_slice.is_empty(), ( + f"Ubatch slice {ubatch_slice} is empty") + request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice + start_locs = attn_metadata.query_start_loc_cpu + first_req = request_slice.start + first_tok = token_slice.start + last_req = request_slice.stop - 1 + last_tok = token_slice.stop - 1 + + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \ + "Token slice start outside of first request" + assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \ + "Token slice end outside of last request" + + # If the "middle" request has tokens in both ubatches, we have to split it. + # If ubatch_slice is the first ubatch then we will be splitting the last + # request. 
If it's the second microbatch, then we will be splitting the + # first request + splits_first_request = first_tok > start_locs[first_req] + splits_last_request = last_tok < start_locs[last_req + 1] - 1 + + query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, request_slice) + assert len(query_start_loc) >= 2, ( f"query_start_loc must have at least 2 elements, " f"got {len(query_start_loc)}") - query_start_loc_cpu = slice_query_start_locs( - attn_metadata.query_start_loc_cpu, request_slice) + if splits_first_request: + tokens_skipped = first_tok - start_locs[first_req] + query_start_loc[1:] -= tokens_skipped + query_start_loc_cpu[1:] -= tokens_skipped seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + + if splits_last_request: + tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + query_start_loc[-1] -= tokens_skipped + query_start_loc_cpu[-1] -= tokens_skipped + + # Make sure we don't modify the seq_lens tensors + # (not cudagraph compatible) + seq_lens = seq_lens.clone() + seq_lens_cpu = seq_lens_cpu.clone() + seq_lens[-1] -= tokens_skipped + seq_lens_cpu[-1] -= tokens_skipped + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -167,6 +205,7 @@ def split_attn_metadata( for ubatch_slice in ubatch_slices: results.append( _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + return results @@ -696,7 +735,6 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dc97d5c8f39d..a0f40828d42f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -30,7 +30,6 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -192,9 +191,8 @@ class EagleProposer: assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - ubatch_id = dbo_current_ubatch_id() attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0) @@ -330,7 +328,7 @@ class EagleProposer: # Rebuild attention metadata attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = attn_metadata_builder\ .build_for_drafting(common_attn_metadata=common_attn_metadata, draft_index=token_index + 1) @@ -538,9 +536,8 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - ubatch_id = dbo_current_ubatch_id() tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + self.runner.attn_groups[0][0].get_metadata_builder() assert isinstance(tree_attn_metadata_builder, 
TreeAttentionMetadataBuilder) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 89b9a3c34f2a..ed324138c6fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split +from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, + ubatch_split) from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens_padded = num_tokens_unpadded + self.get_local_padding( num_tokens_unpadded) ubatch_slices, num_tokens_after_padding = \ - ubatch_split(max_num_scheduled_tokens, + ubatch_split(num_scheduled_tokens, num_tokens_unpadded, num_tokens_padded, self.vllm_config) @@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ubatch_slices, common_attn_metadata) for ubid, common_attn_metadata in enumerate( common_attn_metadata_list): - assert common_attn_metadata.max_query_len == 1 attn_metadata_i = (attn_group.get_metadata_builder( ubatch_id=ubid).build( common_prefix_len=common_prefix_len, @@ -2182,9 +2182,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) = self._preprocess(scheduler_output, intermediate_tensors, ubatch_slices, num_tokens_after_padding) - if ubatch_slices is not None: - num_input_tokens = num_input_tokens // 2 - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( num_scheduled_tokens @@ -2194,6 +2191,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode, batch_descriptor = \ self.cudagraph_dispatcher.dispatch(batch_descriptor) + # This is currently to get around the assert in the DPMetadata + # where it wants `num_tokens_across_dp` to align with `num_tokens` + if ubatch_slices is not None: + num_input_tokens = ubatch_slices[0].num_tokens + # Run the model. # Use persistent buffers for CUDA graphs. with (set_forward_context( @@ -2821,7 +2823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, - allow_microbatching: bool = False, + allow_microbatching: bool = True, skip_eplb: bool = False, is_profile: bool = False, create_mixed_batch: bool = False, @@ -2847,32 +2849,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (1 token) and prefill (multiple tokens) requests. 
remove_lora: If False, dummy LoRAs are not destroyed after the run """ - ubatch_enabled = self.parallel_config.enable_dbo - num_tokens_across_dp = None - num_pad = 0 - should_ubatch = False - if ubatch_enabled: - should_ubatch = num_tokens >= \ - self.parallel_config.dbo_decode_token_threshold and \ - allow_microbatching - - (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch( - num_tokens, num_tokens, should_ubatch, self.vllm_config) - - # Currently the dummy run should only be ubatching during - # cuda graph capture, meaning all DP ranks should already - # have the same batch size - if num_tokens_across_dp is not None: - assert int(num_tokens_across_dp[0]) == num_tokens // 2 - assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - if not should_ubatch: - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens += num_pad - # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. @@ -2888,10 +2868,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens - if allow_microbatching: - assert self.uniform_decode_query_len == 1 - assert uniform_decode is True - assert max_query_len == 1 # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -2930,20 +2906,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None + num_tokens_after_padding = None + # We currently only microbatch if the number of tokens is # over a certain threshold. - if should_ubatch: - # We only support decode-only cudagraphs - assert num_reqs == num_tokens - assert num_tokens % 2 == 0 - ubatch_slices = [ - UBatchSlice(slice(0, num_reqs // 2), slice(0, - num_tokens // 2)), - UBatchSlice(slice(num_reqs // 2, num_reqs), - slice(num_tokens // 2, num_tokens)) - ] + if self.parallel_config.enable_dbo and allow_microbatching: + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + total_num_scheduled_tokens, + total_num_scheduled_tokens, + self.vllm_config, + ) + + # If we failed to microbatch, currently need to resynchronize + # TODO(lucas,sage): we should be able to avoid this second sync by + # refactoring `get_dp_padding_ubatch` and `get_dp_padding` into + # a single `coordinate_batch_across_dp` function. 
+ if num_tokens_after_padding is None: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens_after_padding = num_tokens + num_pad + else: + num_tokens_across_dp = num_tokens_after_padding + num_tokens_after_padding = int(num_tokens_after_padding[0].item()) attn_metadata: Optional[PerLayerAttnMetadata] = None @@ -2966,6 +2953,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() + cum_num_tokens, _ = self._get_cumsum_and_arange( + num_scheduled_tokens) + self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( @@ -3060,7 +3052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_tokens, + num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, @@ -3395,56 +3387,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", cudagraph_runtime_mode.name)) - enable_dbo = self.parallel_config.enable_dbo - # DBO Only supports running Full cudagraphs with uniform - # decode lengths - if enable_dbo and uniform_decode: - for num_tokens in compilation_cases: - # If the number of tokens is greater than the microbatching - # threshold, don't generate a microbatched cudagraph - if (num_tokens - < self.parallel_config.dbo_decode_token_threshold): - continue - # Warmup + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + # We currently only capture ubatched graphs when its a FULL + # cudagraph and for uniform decode batches. + capture_ubatched_graph = self.parallel_config.enable_dbo \ + and cudagraph_runtime_mode == CUDAGraphMode.FULL \ + and uniform_decode \ + and check_ubatch_thresholds( + config=self.vllm_config.parallel_config, + num_tokens=num_tokens, + uniform_decode=uniform_decode, + ) + + # Currently we capture both microbatched and non-microbatched + # graphs when capture_ubatched_graph is True, this is because + # occasionally we will be forced out of microbatching due to other + # DP ranks not microbatching (usually caused by an empty second + # microbatch; once we resolve this, we can remove the + # non-microbatched graph capture). + allow_microbatching_options = [True, False] if \ + capture_ubatched_graph else [False] + for allow_microbatching in allow_microbatching_options: for _ in range( self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. 
force_attention = ( cudagraph_runtime_mode == CUDAGraphMode.FULL) self._dummy_run(num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True) - - # Graph Capture + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False) self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True) - # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, + cudagraph_runtime_mode=cudagraph_runtime_mode, uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - skip_eplb=True, - remove_lora=False) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -3500,24 +3487,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_groups: list[AttentionGroup] = [] for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): - attn_metadata_builders = [] - attn_metadata_builders.append(attn_backend.get_builder_cls()( - kv_cache_spec, + attn_group = AttentionGroup.create_with_metadata_builders( + attn_backend, layer_names, + kv_cache_spec, self.vllm_config, self.device, - )) - if self.parallel_config.enable_dbo: - attn_metadata_builders.append( - attn_backend.get_builder_cls()( - kv_cache_spec, - layer_names, - self.vllm_config, - self.device, - )) - attn_group = AttentionGroup(attn_backend, - attn_metadata_builders, - layer_names, kv_cache_spec) + num_metadata_builders=1 + if not self.parallel_config.enable_dbo else 2, + ) + attn_groups.append(attn_group) return attn_groups diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 5012ad0483c8..bfc3743ea417 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -1,25 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import threading +from dataclasses import dataclass from typing import Any, Callable, Optional import torch +import vllm.envs as envs from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed import get_ep_group from vllm.forward_context import (create_forward_context, get_forward_context, override_forward_context) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import has_deep_gemm from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts logger = init_logger(__name__) -@dataclasses.dataclass +@dataclass class UbatchMetadata: context: UBatchContext 
input_ids: torch.Tensor @@ -29,13 +32,55 @@ class UbatchMetadata: num_tokens: int -@dataclasses.dataclass +@dataclass class CUDAGraphMetaData: cudagraph: torch.cuda.CUDAGraph ubatch_metadata: UbatchMetadata outputs: Optional[Any] = None +class SMControlContextManager: + + def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None], + set_compute_sms: Callable[[int], None]): + """ + Context manager for controlling SM (Streaming Multiprocessor) + allocation. Upon entering the context, it sets the number of SMs + allocated for communication and computation to comm_sms and + total_sms - comm_sms respectively. Upon exiting, it restores the + allocation to use all available SMs (i.e. total_sms). + + Args: + comm_sms (int): The number of SMs to allocate for communication. + (The remainder will be used for computation.) + set_comm_sms (Callable[[int], None]): + A function that sets the number of SMs for communication. + set_compute_sms (Callable[[int], None]): + A function that sets the number of SMs for computation. + """ + + assert current_platform.is_cuda(), \ + "SM control is currently only supported on CUDA" + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + total_sms = props.multi_processor_count + + assert comm_sms < total_sms + self.total_sms = total_sms + self.compute_sms = total_sms - comm_sms + self.comm_sms = comm_sms + self.set_comm_sms = set_comm_sms + self.set_compute_sms = set_compute_sms + + def __enter__(self): + self.set_comm_sms(self.comm_sms) + self.set_compute_sms(self.compute_sms) + + def __exit__(self, exc_type, exc_value, traceback): + self.set_comm_sms(self.total_sms) + self.set_compute_sms(self.total_sms) + + class UBatchWrapper: def __init__(self, runnable: Callable, vllm_config: VllmConfig, @@ -56,6 +101,35 @@ class UBatchWrapper: runnable, vllm_config, runtime_mode=runtime_mode) self.graph_pool = current_platform.get_global_graph_pool() + self.sm_control = self._create_sm_control_context(vllm_config) + + @staticmethod + def _create_sm_control_context(vllm_config: VllmConfig): + comm_sms = envs.VLLM_DBO_COMM_SMS + + set_comm_sms = lambda sms: None + if vllm_config.parallel_config.enable_expert_parallel: + # Currently only DeepEP highthroughput supports SM control so this + # only affects that case. + all2all_manager = get_ep_group( + ).device_communicator.all2all_manager + + if all2all_manager.max_sms_used() is not None: + comm_sms = min(comm_sms, all2all_manager.max_sms_used()) + + if comm_sms > 0: + set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms) + + # TODO(lucas): support other kernels besides DeepGEMM + set_compute_sms = lambda sms: None + if has_deep_gemm() and comm_sms > 0: + import deep_gemm as dg + set_compute_sms = lambda sms: dg.set_num_sms(sms) + + return SMControlContextManager(comm_sms=comm_sms, + set_comm_sms=set_comm_sms, + set_compute_sms=set_compute_sms) + def __getattr__(self, key: str): # allow accessing the attributes of the runnable. 
if hasattr(self.runnable, key): @@ -282,8 +356,8 @@ class UBatchWrapper: dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE) - - return self._capture_ubatches(ubatch_metadata, self.model) + with self.sm_control: + return self._capture_ubatches(ubatch_metadata, self.model) elif num_tokens in self.cudagraphs: cudagraph_metadata = self.cudagraphs[num_tokens] cudagraph_metadata.cudagraph.replay() @@ -300,4 +374,5 @@ class UBatchWrapper: dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE) - return self._run_ubatches(ubatch_metadata, self.model) + with self.sm_control: + return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py index 650f0ec5138d..30acb14ff58a 100644 --- a/vllm/v1/worker/ubatch_splitting.py +++ b/vllm/v1/worker/ubatch_splitting.py @@ -3,9 +3,10 @@ from typing import Optional +import numpy as np import torch -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.forward_context import DPMetadata from vllm.logger import init_logger from vllm.utils import round_up @@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens( dp_size, dp_rank) +def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, + uniform_decode: bool) -> bool: + if not config.enable_dbo: + return False + if uniform_decode: + return num_tokens >= config.dbo_decode_token_threshold + else: + return num_tokens >= config.dbo_prefill_token_threshold + + def get_dp_padding_ubatch( num_tokens_unpadded: int, num_tokens_padded: int, should_attempt_ubatching: bool, @@ -95,9 +106,37 @@ def get_dp_padding_ubatch( dtype=torch.int32) return should_ubatch, num_tokens_after_padding +def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ + -> UBatchSlices: + # TODO(lucas): Refactor the gpu_model_runner.py so we can pass + # in cu_num_tokens directly (i.e. 
query_start_loc) + cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) + np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) + + first_ubatch_token_slice = slice(0, split_point) + second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + + # Determine request slices using exclusive stop semantics + # First ubatch includes requests whose tokens overlap [0, split_point) + first_ubatch_req_stop = int( + np.searchsorted(cu_num_tokens, split_point, side="left")) + first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + + # Second ubatch starts at the request that contains the split_point + # or the request starting exactly at split_point (if on boundary) + second_ubatch_req_start = int( + np.searchsorted(cu_num_tokens, split_point, side="right") - 1) + second_ubatch_req_slice = slice(second_ubatch_req_start, + len(cu_num_tokens) - 1) + + return [ + UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice) + ] + def ubatch_split( - max_num_scheduled_tokens: int, + num_scheduled_tokens_per_request: np.ndarray, num_tokens_unpadded: int, num_tokens_padded: int, vllm_config: VllmConfig, @@ -122,17 +161,20 @@ def ubatch_split( return (None, None) # Check preconditions for microbatching - should_attempt_ubatching = \ - parallel_config.enable_dbo and \ - num_tokens_unpadded >= \ - parallel_config.dbo_decode_token_threshold \ - and max_num_scheduled_tokens == 1 + should_attempt_ubatching = check_ubatch_thresholds( + parallel_config, + num_tokens_unpadded, + vllm_config, + ) # Don't microbatch unless every other DP worker is also microbatching - num_tokens_after_padding = None - (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch( - num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, - vllm_config) + should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch( + num_tokens_unpadded, + num_tokens_padded, + should_attempt_ubatching, + vllm_config, + ) + if not should_ubatch: return (None, None) @@ -141,15 +183,9 @@ def ubatch_split( # to the second ubatch in pad_out_ubatch_slice after attention # metadata creation assert num_tokens_after_padding is not None - total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item()) - padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch) - padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, - num_tokens_unpadded) + token_split_point = int(num_tokens_after_padding[0].item()) - # Note there's an assumption here that there's 1 token per request - ubatch_slices = [ - UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice), - UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice) - ] + ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request, + token_split_point) return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 6716d171cc70..33d58aa94843 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -10,6 +10,14 @@ class UBatchSlice: request_slice: slice token_slice: slice + def is_empty(self) -> bool: + return self.request_slice.start == self.request_slice.stop \ + or self.token_slice.start == self.token_slice.stop + + @property + def num_tokens(self) -> int: + return self.token_slice.stop - self.token_slice.start + UBatchSlices: TypeAlias = list[UBatchSlice] diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 
9aeaa9909dc8..c26cb07123a5 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -51,8 +51,8 @@ class UBatchContext: self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() - # Assume we start on the compute stream - assert current_stream() == self.compute_stream + # Assume we want to start on the compute stream + self.update_stream(self.compute_stream) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -62,17 +62,15 @@ class UBatchContext: self.maybe_run_recv_hook() self.cpu_signal_event.set() self.cpu_wait_event.clear() - self.current_stream = self.compute_stream - torch.cuda.set_stream(self.current_stream) return False def _restore_context(self): forward_context._forward_context = self.forward_context - torch.cuda.set_stream(self.current_stream) def update_stream(self, stream): self.current_stream = stream - torch.cuda.set_stream(self.current_stream) + if current_stream() != self.current_stream: + torch.cuda.set_stream(self.current_stream) def _signal_comm_done(self): self.gpu_comm_done_event.record(self.comm_stream) @@ -99,9 +97,20 @@ class UBatchContext: self.cpu_wait_event.clear() self._restore_context() + def switch_to_comm(self): + self.update_stream(self.comm_stream) + + def switch_to_compute(self): + self.update_stream(self.compute_stream) + def switch_to_comm_sync(self): self._signal_compute_done() self.update_stream(self.comm_stream) + self._wait_compute_done() + + def switch_to_compute_sync(self): + self._signal_comm_done() + self.update_stream(self.compute_stream) self._wait_comm_done() def maybe_run_recv_hook(self): @@ -112,8 +121,7 @@ class UBatchContext: def yield_(self): self.current_stream = current_stream() self._cpu_yield() - if self.current_stream != current_stream(): - self.update_stream(self.current_stream) + self.update_stream(self.current_stream) def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream @@ -153,15 +161,20 @@ def _register_ubatch_function(func): return wrapper +dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( UBatchContext.yield_and_switch_from_compute_to_comm) dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( UBatchContext.yield_and_switch_from_comm_to_compute) -dbo_yield = _register_ubatch_function(UBatchContext.yield_) -dbo_maybe_run_recv_hook = _register_ubatch_function( - UBatchContext.maybe_run_recv_hook) +dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm) +dbo_switch_to_compute = _register_ubatch_function( + UBatchContext.switch_to_compute) dbo_switch_to_comm_sync = _register_ubatch_function( UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute_sync = _register_ubatch_function( + UBatchContext.switch_to_compute_sync) def dbo_register_recv_hook(recv_hook): diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index af922f9979d1..553d33e27203 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -130,15 +130,32 @@ class MultiModalBudget: @dataclass class AttentionGroup: backend: type[AttentionBackend] + # When ubatching is enabled we will have a metadata builder for each ubatch + # so that if they use internal persistant buffers for cudagraphs, and they + # won't have to worry about conflicting with the other ubatches. 
metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] kv_cache_spec: KVCacheSpec + @staticmethod + def create_with_metadata_builders( + backend: type[AttentionBackend], + layer_names: list[str], + kv_cache_spec: KVCacheSpec, + vllm_config: VllmConfig, + device: torch.device, + num_metadata_builders: int = 1, + ) -> 'AttentionGroup': + metadata_builders = [ + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, + device) + for _ in range(num_metadata_builders) + ] + return AttentionGroup(backend, metadata_builders, layer_names, + kv_cache_spec) + def get_metadata_builder(self, - ubatch_id: Optional[int] = None - ) -> AttentionMetadataBuilder: - if ubatch_id is None: - return self.metadata_builders[0] + ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id return self.metadata_builders[ubatch_id]
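
For reference, the request/token split implemented by `create_ubatch_slices` in this patch can be exercised in isolation. The sketch below is illustrative only: `split_requests` is a hypothetical standalone helper that mirrors the `searchsorted`-based logic above, showing how a request that straddles the split point appears in both micro-batches (the same case covered by `test_prefill_split_across_ubatches`).

    import numpy as np

    def split_requests(num_scheduled_tokens: np.ndarray, split_point: int):
        # Cumulative token offsets per request (the query_start_loc equivalent).
        cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
        np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu[1:])
        # First ubatch: every request with at least one token in [0, split_point).
        first_req_stop = int(np.searchsorted(cu, split_point, side="left"))
        # Second ubatch: starts at the request that contains split_point.
        second_req_start = int(np.searchsorted(cu, split_point, side="right") - 1)
        return ((slice(0, first_req_stop), slice(0, split_point)),
                (slice(second_req_start, len(cu) - 1),
                 slice(split_point, int(cu[-1]))))

    # Two requests with 8 query tokens each, split at token 12: request 1
    # straddles the boundary, so it lands in both micro-batches.
    print(split_requests(np.array([8, 8], dtype=np.int32), 12))
    # ((slice(0, 2, None), slice(0, 12, None)),
    #  (slice(1, 2, None), slice(12, 16, None)))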