[Core/DBO][2/N] Dual-Batch Overlap: add DeepEP High Throughput support and Prefill support (#24845)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent a903669e10
commit cc1dc7ed6d
@@ -5,11 +5,12 @@ import pytest
 import torch
 
 from tests.v1.attention.test_attention_backends import BATCH_SPECS
-from tests.v1.attention.utils import create_common_attn_metadata
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
                                               split_attn_metadata)
+from vllm.v1.worker.ubatch_utils import create_ubatch_slices
 
 
 @pytest.fixture
@@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     assert results[1].num_reqs == mid_point
     assert results[1].num_actual_tokens == mid_point
     assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
+
+
+@pytest.mark.parametrize(
+    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
+    [
+        # Split in the middle of request 1
+        ([32, 40], [8, 8], 12, 2, 1),
+        # Split inside the first request
+        ([32, 40], [8, 8], 4, 1, 2),
+    ],
+)
+def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
+                                       expected_first_reqs,
+                                       expected_second_reqs):
+    """Test splitting a prefill across ubatches"""
+    import numpy as np
+
+    device = torch.device("cpu")
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
+    common = create_common_attn_metadata(batch_spec,
+                                         block_size=16,
+                                         device=device)
+
+    num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
+    qsl_np = common.query_start_loc_cpu.numpy()
+    num_tokens = common.num_actual_tokens
+
+    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
+    assert len(ubatch_slices) == 2
+
+    first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
+    second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
+
+    # Token counts match the split
+    assert first_meta.num_actual_tokens == split_point
+    assert second_meta.num_actual_tokens == num_tokens - split_point
+
+    # Number of requests per ubatch
+    assert first_meta.num_reqs == expected_first_reqs
+    assert second_meta.num_reqs == expected_second_reqs
+
+    # Identify which request is split and how many tokens are in the first chunk
+    split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
+    tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
+    orig_q_lens = (common.query_start_loc_cpu[1:] -
+                   common.query_start_loc_cpu[:-1])
+
+    # Check query length continuity: first-chunk + second-chunk == original qlen
+    # First ubatch last request query length
+    qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
+                          first_meta.query_start_loc_cpu[-2])
+    # Second ubatch first request query length
+    qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
+                            second_meta.query_start_loc_cpu[0])
+    assert qlen_first_last == tokens_in_first_chunk
+    assert qlen_first_last + qlen_second_first == int(
+        orig_q_lens[split_req_idx])
+
+    # Check seq_lens adjustments
+    # Context lengths per original request
+    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
+
+    # First ubatch: last request's seq_len should be
+    # context + tokens_in_first_chunk
+    expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
+    assert int(first_meta.seq_lens[-1]) == expected_seqlen
+
+    # For full preceding requests in first ubatch, seq_lens should match
+    # originals
+    for i in range(first_meta.num_reqs - 1):
+        assert int(first_meta.seq_lens[i]) == seq_lens[i]
+
+    # Second ubatch: first request (continuation) seq_len should be full
+    # original
+    assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
+    # Any following full requests in second ubatch should match originals
+    for j in range(1, second_meta.num_reqs):
+        # Map to original request index
+        orig_idx = split_req_idx + j
+        assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]
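Note (illustrative sketch, not part of the diff): the index arithmetic the test relies on is easy to check by hand. With query_lens=[8, 8] the cumulative query_start_loc is [0, 8, 16], so a split at token 12 lands inside request 1 and leaves 4 tokens in the first chunk. The snippet below is standalone and only restates that arithmetic:

    import numpy as np

    query_lens = [8, 8]
    split_point = 12
    # Cumulative token offsets per request -> [0, 8, 16]
    qsl = np.concatenate(([0], np.cumsum(query_lens)))
    # Request containing the split token (same searchsorted trick as the test)
    split_req_idx = int(np.searchsorted(qsl, split_point, side="right") - 1)
    tokens_in_first_chunk = split_point - int(qsl[split_req_idx])
    print(split_req_idx, tokens_in_first_chunk)  # -> 1 4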
@@ -532,9 +532,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [
+    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
-    ]
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -659,9 +658,8 @@ def test_propose_tree(spec_token_tree):
     # Mock runner for attention metadata building.
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [
+    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
-    ]
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
@@ -638,11 +638,13 @@ class VllmConfig:
 
         if self.parallel_config.enable_dbo:
             a2a_backend = envs.VLLM_ALL2ALL_BACKEND
-            assert a2a_backend == "deepep_low_latency", \
-            "Microbatching currently only supports the deepep_low_latency "\
-            f"all2all backend. {a2a_backend} is not supported. To fix set "\
-            "the VLLM_ALL2ALL_BACKEND environment variable to "\
-            "deepep_low_latency and install the DeepEP kerenls."
+            assert a2a_backend in \
+                ["deepep_low_latency", "deepep_high_throughput"], \
+            "Microbatching currently only supports the deepep_low_latency and "\
+            f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
+            "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
+            "variable to deepep_low_latency or deepep_high_throughput and "\
+            "install the DeepEP kernels."
 
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
@@ -139,12 +139,18 @@ class ParallelConfig:
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
     enable_dbo: bool = False
-    """Enable microbatching for the model executor."""
+    """Enable dual batch overlap for the model executor."""
 
     dbo_decode_token_threshold: int = 32
-    """The threshold for microbatching. If the number of tokens in the
-    request is greater than this threshold, microbatching will be used.
-    Otherwise, the request will be processed in a single batch."""
+    """The threshold for dual batch overlap for batches only containing decodes.
+    If the number of tokens in the request is greater than this threshold,
+    microbatching will be used. Otherwise, the request will be processed in a
+    single batch."""
+    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
+    """The threshold for dual batch overlap for batches that contain one or more
+    prefills. If the number of tokens in the request is greater than this
+    threshold, microbatching will be used. Otherwise, the request will be
+    processed in a single batch."""
 
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
 
+import vllm.envs as envs
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
@@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
 
     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
-        num_nvl_bytes = 1024 * 1024 * 1024
+        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
         num_rdma_bytes = None
         num_qps_per_rank = None
 
         if self.internode:
-            num_rdma_bytes = 1024 * 1024 * 1024
+            num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
             num_qps_per_rank = self.num_sms // 2
         else:
             num_rdma_bytes = 0
@@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
-        # It is dangerous to set num sms outside this function. num_sms is not
-        # a part of the hash-key that identifies this object. If we are in a
-        # situation where we make objects with different num_sms, the hash key
-        # in get_or_create must be updated.
-        handle.set_num_sms(self.num_sms)
         return handle
 
+    def set_num_sms(self, num_sms: int):
+        import deep_ep
+
+        # Right now the buffers are sized for only what the kernels were
+        # created with. So we can only reduce the number of SMs used
+        # but not increase it.
+        if num_sms > self.num_sms:
+            num_sms = self.num_sms
+        deep_ep.Buffer.set_num_sms(num_sms)
+
 
 class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     """
@@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         import deep_ep
 
         # Defaults for internode and intranode are taken from DeepEP tests.
-        num_nvl_bytes = 1024 * 1024 * 1024
+        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
         num_qps_per_rank = num_local_experts
         num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
             num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
         return handle
+
+    # DeepEP LL uses RDMA so no SMs are used for communication
+    def max_sms_used(self) -> Optional[int]:
+        return 0
@@ -60,6 +60,12 @@ class All2AllManagerBase:
         # and reuse it for the same config.
         raise NotImplementedError
 
+    def set_num_sms(self, num_sms: int):
+        pass
+
+    def max_sms_used(self) -> Optional[int]:
+        return None  # None means it could use the whole GPU
+
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
         raise NotImplementedError
@@ -330,6 +330,8 @@ class EngineArgs:
     enable_dbo: bool = ParallelConfig.enable_dbo
     dbo_decode_token_threshold: int = \
        ParallelConfig.dbo_decode_token_threshold
+    dbo_prefill_token_threshold: int = \
+        ParallelConfig.dbo_prefill_token_threshold
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
@@ -698,6 +700,9 @@ class EngineArgs:
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
             **parallel_kwargs["dbo_decode_token_threshold"])
+        parallel_group.add_argument(
+            "--dbo-prefill-token-threshold",
+            **parallel_kwargs["dbo_prefill_token_threshold"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1316,6 +1321,7 @@ class EngineArgs:
             enable_expert_parallel=self.enable_expert_parallel,
             enable_dbo=self.enable_dbo,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
+            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,
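Note (illustrative sketch, not part of the diff): the new knob is exposed both as a CLI flag (--dbo-prefill-token-threshold, next to the existing --dbo-decode-token-threshold) and as an EngineArgs field. Assuming vLLM is installed and the usual dataclass-style construction, a caller could set it like this; the model name is a placeholder and constructing the dataclass does not load anything:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="some/model",                # placeholder
        enable_dbo=True,                   # turn on dual batch overlap
        dbo_decode_token_threshold=32,     # decode-only batches
        dbo_prefill_token_threshold=512,   # batches containing prefills
    )
    print(args.dbo_prefill_token_threshold)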
vllm/envs.py
@@ -189,6 +189,8 @@ if TYPE_CHECKING:
     VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
     VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
     VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
+    VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
+    VLLM_DBO_COMM_SMS: int = 20
     GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
     VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
@@ -1392,6 +1394,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
                       "VLLM_OBJECT_STORAGE_SHM_BUFFER"),
 
+    # The size in MB of the buffers (NVL and RDMA) used by DeepEP
+    "VLLM_DEEPEP_BUFFER_SIZE_MB":
+    lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),
+
+    # The number of SMs to allocate for communication kernels when running DBO
+    # the rest of the SMs on the device will be allocated to compute
+    "VLLM_DBO_COMM_SMS":
+    lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
+
     # Valid values are container,code_interpreter,web_search_preview
     # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
     "GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
@@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
 from vllm.utils import round_up
+from vllm.v1.worker.ubatching import (
+    dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
+    dbo_switch_to_compute, dbo_switch_to_compute_sync,
+    dbo_yield_and_switch_from_comm_to_compute,
+    dbo_yield_and_switch_from_compute_to_comm)
 
 
 class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self.async_prepare = True
 
         # The dispatch function returns a handle that the combine function
-        # requires. We store the handle here so it is available to the
-        # combine function.
-        self.handle = None
+        # requires. Under DBO microbatching we must track one handle per
+        # micro-batch to avoid races between threads.
+        self.handles = [None, None]
 
         # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
         self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
@@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
         has_scales = token_scales is not None
 
+        # We yield before launching the dispatch kernel since the dispatch
+        # kernel will block the CPU so we want to queue up all the compute
+        # for the other ubatch before the dispatch kernel starts.
+        dbo_yield_and_switch_from_compute_to_comm()
+
         (num_tokens_per_rank, num_tokens_per_rdma_rank,
          dispatch_expert_num_tokens, is_token_in_rank,
          event) = self.buffer.get_dispatch_layout(
@@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
         (
             token_data, expert_topk_ids, expert_topk_weights,
-            expert_num_tokens_per_expert_list, self.handle, event
+            expert_num_tokens_per_expert_list, handle, event
         ) = self.buffer.dispatch(
             x=token_data,
             handle=None,
@@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             expert_alignment=1,
             config=self._get_dispatch_config(),
             previous_event=None,
-            async_finish=self.async_prepare,
+            async_finish=self.async_prepare and not dbo_enabled(),
             allocate_on_comm_stream=False)
 
+        # record the handle for this ubatch
+        a2a_idx = dbo_current_ubatch_id()
+        self.handles[a2a_idx] = handle
+
+        dbo_switch_to_compute_sync()
+
         return lambda: self._receiver(
             event,
             has_scales,
@@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         a1_scale: Optional[torch.Tensor],
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        if self.async_prepare:
+        if event.event is not None:
             event.current_stream_wait()
 
         if has_scales:
@@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[Callable, mk.ReceiverType]:
+    ) -> mk.ReceiverType:
 
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             a1q_scale = None
             a1_post_scale = quant_config.a1_scale
 
-        return (lambda *args: None,
-                self._do_dispatch(tokens=a1q,
-                                  token_scales=a1q_scale,
-                                  rank_topk_ids=topk_ids,
-                                  rank_topk_weights=topk_weights,
-                                  num_experts=num_experts,
-                                  a1_scale=a1_post_scale,
-                                  quant_config=quant_config))
+        return self._do_dispatch(tokens=a1q,
+                                 token_scales=a1q_scale,
+                                 rank_topk_ids=topk_ids,
+                                 rank_topk_weights=topk_weights,
+                                 num_experts=num_experts,
+                                 a1_scale=a1_post_scale,
+                                 quant_config=quant_config)
 
     def prepare(
         self,
@@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
-                                           num_experts, expert_map,
-                                           apply_router_weight_on_input,
-                                           quant_config)
+        receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts,
+                                      expert_map, apply_router_weight_on_input,
+                                      quant_config)
         return receiver()
 
     def _finalize(
@@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         do_async: bool,
     ) -> Optional[Callable]:
 
-        assert self.handle is not None
+        a2a_idx = dbo_current_ubatch_id()
+        handle = self.handles[a2a_idx]
+        assert handle is not None
 
         # fused_expert_output can have 0 tokens - This happens when none of the
         # tokens from the all2all reach this EP rank.
@@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 topk_ids=topk_ids,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
+        dbo_yield_and_switch_from_compute_to_comm()
         combined_x, _, event = self.buffer.combine(
             x=fused_expert_output,
-            handle=self.handle,
+            handle=handle,
             topk_weights=None,
             config=self._get_combine_config(),
             previous_event=None,
-            async_finish=do_async,
+            async_finish=do_async and not dbo_enabled(),
             allocate_on_comm_stream=False)
 
+        dbo_switch_to_compute()
+
         if do_async:
 
             def _receiver():
-                event.current_stream_wait()
+                if event.event is not None:
+                    event.current_stream_wait()
+                dbo_switch_to_comm()
                 # Respect inplace outputs.
                 output.copy_(combined_x, non_blocking=True)
 
-            return lambda: _receiver()
+            # TODO(lucas): refactor the modular kernel so this will be
+            # handled there
+            dbo_yield_and_switch_from_comm_to_compute()
+
+            return _receiver
         else:
+            # TODO(lucas): support this case with the refactored modular kernel
+            assert not dbo_enabled()
             # Respect inplace outputs.
             output.copy_(combined_x, non_blocking=True)
             return None
@@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
         do_async: bool,
-    ) -> Optional[Callable]:
+    ) -> tuple[Callable, Callable]:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             return_recv_hook=do_recv_hook,
             out=output)
 
-        return recv_hook
+        return recv_hook, lambda: None
 
     def finalize_async(
         self,
@@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> Callable:
-        recv_hook = self._finalize(
+    ) -> tuple[Callable, Callable]:
+        return self._finalize(
             output,
             fused_expert_output,
             topk_weights,
@@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             weight_and_reduce_impl,
             do_async=True,
         )
-        assert recv_hook is not None
-        return recv_hook
 
     def finalize(
         self,
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
     _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
-from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
+from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
+                                      dbo_maybe_run_recv_hook,
                                       dbo_register_recv_hook, dbo_yield)
 
 #
@@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[Callable, ReceiverType]:
+    ) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
         """
         Perform any quantization (and/or) dispatching needed for this kernel
         but do not wait for results from other workers.
@@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.
 
-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `prepare`, e.g.
+        Returns a callback or a hook/callback pair that, when invoked, waits
+        for results from other workers and has the same return signature as
+        `prepare`. If a hook is returned, it is a more lightweight check that
+        the recv is complete without doing extra work (used by DBO, will be
+        refactored in the very near future)
+
+        e.g.
 
-        receiver = obj.prepare_async(...)
+        ret = obj.prepare_async(...)
+
+        if isinstance(ret, tuple):
+            hook, receiver = ret
+        if hook is not None:
+            hook()
+
         a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
 
         is equivalent to:
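Note (illustrative sketch, not part of the diff): to make the calling convention above concrete, here is a small self-contained example of unpacking either return form. The async op and its results are stand-ins, not vLLM code.

    from typing import Callable, Optional, Union

    def fake_prepare_async() -> Union[tuple[Callable, Callable], Callable]:
        # Stand-in for prepare_async(): returns (hook, receiver) or just receiver.
        def hook() -> None:
            print("lightweight completion check")
        def receiver() -> str:
            return "prepared tensors"
        return hook, receiver

    ret = fake_prepare_async()
    hook: Optional[Callable]
    hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
    if hook is not None:
        hook()           # cheap wait/check for the recv to complete
    result = receiver()  # full wait + unpack, same signature as prepare()
    print(result)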
@@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: TopKWeightAndReduce,
-    ) -> Callable:
+    ) -> Union[tuple[Callable, Callable], Callable]:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output but do not wait for results from other workers.
@@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
         - weight_and_reduce_impl: An optional TopKWeightAndReduce
           implementation.
 
-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `finalize`, e.g.
+        Returns a callback or a hook/callback pair that, when invoked, waits
+        for results from other workers and has the same return signature as
+        `finalize`. If a hook is returned, it is a more lightweight check that
+        the recv is complete without doing extra work (used by DBO, will be
+        refactored in the very near future)
 
-        receiver = obj.finalize_async(output, ...)
+        ret = obj.finalize_async(output, ...)
         ... output not valid yet ...
+        if isinstance(ret, tuple):
+            hook, receiver = ret
+            hook()
         receiver()
         ... output valid here ...
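Note (illustrative sketch, not part of the diff): under DBO the hook is not called inline; it is handed to the ubatch context so the other microbatch can run while the recv completes, which is what the modular-kernel changes further down do via dbo_register_recv_hook/dbo_yield. The function below is schematic and takes those hooks as stand-in parameters rather than importing vLLM internals.

    from typing import Callable, Optional

    def handle_async_ret(ret, dbo_enabled: bool,
                         register_recv_hook: Callable[[Callable], None],
                         dbo_yield: Callable[[], None]):
        # Unpack (hook, receiver) or bare receiver, mirroring the docstring above.
        hook: Optional[Callable]
        hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
        if hook is not None:
            if dbo_enabled:
                # Let the ubatch scheduler run the hook; switch to the other ubatch.
                register_recv_hook(hook)
                dbo_yield()
            else:
                hook()
        return receiver()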
@@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the component
     objects.
     """
-    fused_out_buffer = SharedResizableBuffer()
-    workspace13_buffer = SharedResizableBuffer()
-    workspace2_buffer = SharedResizableBuffer()
+
+    class SharedBuffers:
+
+        def __init__(self) -> None:
+            self.fused_out = SharedResizableBuffer()
+            self.workspace13 = SharedResizableBuffer()
+            self.workspace2 = SharedResizableBuffer()
+
+    # Persistent buffers that are shared across `FusedMoEModularKernel`
+    # instances (layers), to save memory and allocations.
+    #
+    # We have two sets of buffers to support dual batch overlap (DBO) where each
+    # microbatch (ubatch) should use its own set of buffers to avoid
+    # cross-ubatch contamination.
+    # NOTE that memory is lazily allocated for these buffers, meaning that if
+    # DBO isn't being used, the second SharedBuffers will be empty.
+    shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
 
     def __init__(
         self,
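Note (illustrative sketch, not part of the diff): a condensed standalone version of the per-ubatch buffer idea, with stand-in class names rather than the vLLM ones. Two buffer sets are indexed by the current microbatch id, so concurrent ubatches never hand out the same workspace.

    import torch

    class _Buffers:
        """Stand-in for SharedBuffers: lazily sized scratch tensors."""
        def __init__(self) -> None:
            self.workspace: torch.Tensor = torch.empty(0)

        def get(self, numel: int) -> torch.Tensor:
            if self.workspace.numel() < numel:
                self.workspace = torch.empty(numel)
            return self.workspace[:numel]

    # One set per microbatch; index with the current ubatch id (0 or 1).
    shared = [_Buffers(), _Buffers()]

    def get_workspace(ubatch_id: int, numel: int) -> torch.Tensor:
        return shared[ubatch_id].get(numel)

    a = get_workspace(0, 1024)
    b = get_workspace(1, 1024)
    assert a.data_ptr() != b.data_ptr()  # no cross-ubatch aliasing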
@@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)
 
+        # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self.shared_buffers[ubatch_idx]
+
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = self.workspace13_buffer.get(workspace13_shape,
-                                                  device=a1.device,
-                                                  dtype=workspace_dtype)
-        workspace2 = self.workspace2_buffer.get(workspace2_shape,
-                                                device=a1.device,
-                                                dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(workspace13_shape,
+                                              device=a1.device,
+                                              dtype=workspace_dtype)
+        workspace2 = buffers.workspace2.get(workspace2_shape,
+                                            device=a1.device,
+                                            dtype=workspace_dtype)
 
         assert fused_out is None or fused_out.shape == fused_out_shape, (
             f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
         (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)
-        fused_out = self.fused_out_buffer.get(fused_out_shape,
-                                              device=a1q.device,
-                                              dtype=a1.dtype)
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self.shared_buffers[ubatch_idx]
+        fused_out = buffers.fused_out.get(fused_out_shape,
+                                          device=a1q.device,
+                                          dtype=a1.dtype)
 
         def slice_input_tensors(
                 chunk_idx: int
@@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         if not self.prepare_finalize.supports_async():
             # We shouldn't be running an a2a kernel that doesn't
             # support async prepare/finalize
+            # TODO(lucas): enable in follow-up
             assert not dbo_enabled()
 
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
@@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         else:
             # Overlap shared expert compute with all2all dispatch.
             dbo_maybe_run_recv_hook()
-            hook, receiver = self.prepare_finalize.prepare_async(
+            prepare_ret = self.prepare_finalize.prepare_async(
                 a1,
                 topk_weights,
                 topk_ids,
@@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.quant_config,
             )
 
-            # If DBO is being used, register the hook with the ubatch context
-            # and call it in dbo_maybe_run_recv_hook instead of passing it to
-            # the receiver.
-            dbo_register_recv_hook(hook)
-            dbo_yield()
-            if not dbo_enabled():
-                hook()
+            # TODO(lucas): refactor this in the alternative schedules followup
+            # currently unpack if we have hook + receiver pair or just
+            # receiver (see finalize_async docstring)
+            hook, receiver = prepare_ret \
+                if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+
+            if hook is not None:
+                if dbo_enabled():
+                    # If DBO is being used, register the hook with the ubatch
+                    # context and call it in dbo_maybe_run_recv_hook instead of
+                    # passing it to the receiver.
+                    dbo_register_recv_hook(hook)
+                    dbo_yield()
+                else:
+                    hook()
 
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = receiver()
@@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             if self.shared_experts is not None:
                 shared_output = self.shared_experts(a1)
         else:
-            recv_hook = self.prepare_finalize.finalize_async(
+            finalize_ret = self.prepare_finalize.finalize_async(
                 output,
                 fused_out,
                 topk_weights,
@@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
             if self.shared_experts is not None:
                 shared_output = self.shared_experts(a1)
 
-            assert recv_hook is not None
-            dbo_register_recv_hook(recv_hook)
-            dbo_yield()
-            if not dbo_enabled():
-                recv_hook()
+            # TODO(lucas): refactor this in the alternative schedules followup
+            # currently unpack if we have hook + receiver pair or just
+            # receiver (see finalize_async docstring)
+            hook, receiver = finalize_ret \
+                if isinstance(finalize_ret, tuple) else (None, finalize_ret)
+
+            if hook is not None:
+                if dbo_enabled():
+                    # If DBO is being used, register the hook with the ubatch
+                    # context and call it in dbo_maybe_run_recv_hook instead of
+                    # passing it to the receiver.
+                    dbo_register_recv_hook(hook)
+                    dbo_yield()
+                else:
+                    hook()
+
+            receiver()
 
         if self.shared_experts is None:
             return output
@@ -107,19 +107,57 @@ def _make_metadata_with_slice(
         the requests included in ubatch_slice
     """
+
+    assert not ubatch_slice.is_empty(), (
+        f"Ubatch slice {ubatch_slice} is empty")
+
     request_slice = ubatch_slice.request_slice
     token_slice = ubatch_slice.token_slice
 
+    start_locs = attn_metadata.query_start_loc_cpu
+    first_req = request_slice.start
+    first_tok = token_slice.start
+    last_req = request_slice.stop - 1
+    last_tok = token_slice.stop - 1
+
+    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
+        "Token slice start outside of first request"
+    assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], \
+        "Token slice end outside of last request"
+
+    # If the "middle" request has tokens in both ubatches, we have to split it.
+    # If ubatch_slice is the first ubatch then we will be splitting the last
+    # request. If it's the second microbatch, then we will be splitting the
+    # first request
+    splits_first_request = first_tok > start_locs[first_req]
+    splits_last_request = last_tok < start_locs[last_req + 1] - 1
+
+    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
     query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                              request_slice)
+
     assert len(query_start_loc) >= 2, (
         f"query_start_loc must have at least 2 elements, "
         f"got {len(query_start_loc)}")
-    query_start_loc_cpu = slice_query_start_locs(
-        attn_metadata.query_start_loc_cpu, request_slice)
 
+    if splits_first_request:
+        tokens_skipped = first_tok - start_locs[first_req]
+        query_start_loc[1:] -= tokens_skipped
+        query_start_loc_cpu[1:] -= tokens_skipped
     seq_lens = attn_metadata.seq_lens[request_slice]
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+
+    if splits_last_request:
+        tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
+        query_start_loc[-1] -= tokens_skipped
+        query_start_loc_cpu[-1] -= tokens_skipped
+
+        # Make sure we don't modify the seq_lens tensors
+        # (not cudagraph compatible)
+        seq_lens = seq_lens.clone()
+        seq_lens_cpu = seq_lens_cpu.clone()
+        seq_lens[-1] -= tokens_skipped
+        seq_lens_cpu[-1] -= tokens_skipped
+
     max_seq_len = int(seq_lens_cpu.max())
     num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
         request_slice]
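Note (illustrative sketch, not part of the diff): in numbers, take two prefills with query lengths [8, 8] (query_start_loc = [0, 8, 16]) and seq lengths [32, 40], split at token 12. The first ubatch keeps both requests but its last boundary and last seq_len are shortened by the 4 tokens that moved out; the second ubatch starts mid-request, so its query_start_loc is rebased while the continued request keeps its full original seq_len. The standalone snippet below mirrors that adjustment, not the vLLM helper itself:

    import torch

    # Two prefills: query lens [8, 8], full seq lens [32, 40]; split tokens at 12.
    query_start_loc = torch.tensor([0, 8, 16])
    seq_lens = torch.tensor([32, 40])
    split_point = 12

    # First ubatch: requests [0, 2), tokens [0, 12). Clamp the last boundary.
    first_qsl = query_start_loc.clone()
    tokens_cut = int(first_qsl[-1]) - split_point           # 4 tokens move out
    first_qsl[-1] -= tokens_cut                              # [0, 8, 12]
    first_seq_lens = seq_lens.clone()
    first_seq_lens[-1] -= tokens_cut                         # [32, 36]

    # Second ubatch: request [1, 2), tokens [12, 16). Rebase and skip the
    # 4 tokens already handled by the first ubatch.
    second_qsl = query_start_loc[1:] - query_start_loc[1]    # [0, 8]
    tokens_skipped = split_point - int(query_start_loc[1])   # 4
    second_qsl[1:] -= tokens_skipped                          # [0, 4]
    second_seq_lens = seq_lens[1:]                            # [40] (unchanged)

    print(first_qsl.tolist(), first_seq_lens.tolist())    # [0, 8, 12] [32, 36]
    print(second_qsl.tolist(), second_seq_lens.tolist())  # [0, 4] [40]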
@@ -167,6 +205,7 @@ def split_attn_metadata(
     for ubatch_slice in ubatch_slices:
         results.append(
             _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
 
     return results
 
+
@@ -696,7 +735,6 @@ def split_decodes_and_prefills(
         return num_reqs, 0, num_tokens, 0
 
     first_prefill = is_prefill.int().argmax(dim=-1).item()
-    assert torch.all(query_lens[first_prefill:] > decode_threshold)
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
@@ -30,7 +30,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 logger = init_logger(__name__)
 
@@ -192,9 +191,8 @@ class EagleProposer:
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
-        ubatch_id = dbo_current_ubatch_id()
         attn_metadata_builder = \
-            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
+            self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = attn_metadata_builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)
@@ -330,7 +328,7 @@ class EagleProposer:
 
             # Rebuild attention metadata
             attn_metadata_builder = \
-                self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
+                self.runner.attn_groups[0][0].get_metadata_builder()
             attn_metadata = attn_metadata_builder\
                 .build_for_drafting(common_attn_metadata=common_attn_metadata,
                                     draft_index=token_index + 1)
@@ -538,9 +536,8 @@ class EagleProposer:
         hidden_states: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> list[torch.Tensor]:
-        ubatch_id = dbo_current_ubatch_id()
         tree_attn_metadata_builder = \
-            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
+            self.runner.attn_groups[0][0].get_metadata_builder()
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
 
@@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
+from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
+                                             ubatch_split)
 from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
@@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
             num_tokens_unpadded)
         ubatch_slices, num_tokens_after_padding = \
-            ubatch_split(max_num_scheduled_tokens,
+            ubatch_split(num_scheduled_tokens,
                          num_tokens_unpadded,
                          num_tokens_padded,
                          self.vllm_config)
@@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     ubatch_slices, common_attn_metadata)
                 for ubid, common_attn_metadata in enumerate(
                         common_attn_metadata_list):
-                    assert common_attn_metadata.max_query_len == 1
                     attn_metadata_i = (attn_group.get_metadata_builder(
                         ubatch_id=ubid).build(
                             common_prefix_len=common_prefix_len,
@@ -2182,9 +2182,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         ) = self._preprocess(scheduler_output, intermediate_tensors,
                              ubatch_slices, num_tokens_after_padding)
 
-        if ubatch_slices is not None:
-            num_input_tokens = num_input_tokens // 2
-
         uniform_decode = (max_query_len
                           == self.uniform_decode_query_len) and (
                               num_scheduled_tokens
@@ -2194,6 +2191,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         cudagraph_runtime_mode, batch_descriptor = \
             self.cudagraph_dispatcher.dispatch(batch_descriptor)
 
+        # This is currently to get around the assert in the DPMetadata
+        # where it wants `num_tokens_across_dp` to align with `num_tokens`
+        if ubatch_slices is not None:
+            num_input_tokens = ubatch_slices[0].num_tokens
+
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with (set_forward_context(
@@ -2821,7 +2823,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         force_attention: bool = False,
         uniform_decode: bool = False,
-        allow_microbatching: bool = False,
+        allow_microbatching: bool = True,
         skip_eplb: bool = False,
         is_profile: bool = False,
         create_mixed_batch: bool = False,
@@ -2847,32 +2849,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        ubatch_enabled = self.parallel_config.enable_dbo
-        num_tokens_across_dp = None
-        num_pad = 0
-        should_ubatch = False
-        if ubatch_enabled:
-            should_ubatch = num_tokens >= \
-                self.parallel_config.dbo_decode_token_threshold and \
-                allow_microbatching
-
-            (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
-                num_tokens, num_tokens, should_ubatch, self.vllm_config)
-
-            # Currently the dummy run should only be ubatching during
-            # cuda graph capture, meaning all DP ranks should already
-            # have the same batch size
-            if num_tokens_across_dp is not None:
-                assert int(num_tokens_across_dp[0]) == num_tokens // 2
-
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
-        if not should_ubatch:
-            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
-            num_tokens += num_pad
-
         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.separate_routine(). This means that we are using
         # different graphs and/or modes for mixed prefill-decode batches vs.
@@ -2888,10 +2868,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # for GQA/MQA.
         max_query_len = self.uniform_decode_query_len if uniform_decode else \
             num_tokens
-        if allow_microbatching:
-            assert self.uniform_decode_query_len == 1
-            assert uniform_decode is True
-            assert max_query_len == 1
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -2930,20 +2906,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
+        total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
 
         ubatch_slices = None
+        num_tokens_after_padding = None
 
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
-        if should_ubatch:
-            # We only support decode-only cudagraphs
-            assert num_reqs == num_tokens
-            assert num_tokens % 2 == 0
-            ubatch_slices = [
-                UBatchSlice(slice(0, num_reqs // 2), slice(0,
-                                                           num_tokens // 2)),
-                UBatchSlice(slice(num_reqs // 2, num_reqs),
-                            slice(num_tokens // 2, num_tokens))
-            ]
+        if self.parallel_config.enable_dbo and allow_microbatching:
+            ubatch_slices, num_tokens_after_padding = ubatch_split(
+                num_scheduled_tokens,
+                total_num_scheduled_tokens,
+                total_num_scheduled_tokens,
+                self.vllm_config,
+            )
+
+        # If we failed to microbatch, currently need to resynchronize
+        # TODO(lucas,sage): we should be able to avoid this second sync by
+        # refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
+        # a single `coordinate_batch_across_dp` function.
+        if num_tokens_after_padding is None:
+            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+            num_tokens_after_padding = num_tokens + num_pad
+        else:
+            num_tokens_across_dp = num_tokens_after_padding
+            num_tokens_after_padding = int(num_tokens_after_padding[0].item())
 
         attn_metadata: Optional[PerLayerAttnMetadata] = None
 
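As an aside, here is a minimal pure-Python sketch of the padding fallback in the hunk above: when microbatching is declined there is no negotiated per-ubatch token count and the runner pads locally, otherwise the first entry of the negotiated tensor is taken as the per-ubatch size. The helper name and the list-as-tensor stand-in below are illustrative only, not vLLM APIs.

def resolve_num_tokens(num_tokens, num_tokens_after_padding, num_pad):
    # Mirrors the branch above: no negotiated ubatch padding -> pad locally.
    if num_tokens_after_padding is None:
        return num_tokens + num_pad
    # Negotiated case: element 0 holds the padded size of each microbatch.
    return int(num_tokens_after_padding[0])

# Hypothetical numbers: 96 tokens padded locally by 32, or a negotiated
# per-ubatch size of 64 agreed across DP ranks.
assert resolve_num_tokens(96, None, 32) == 128
assert resolve_num_tokens(96, [64, 64], 0) == 64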
@@ -2966,6 +2953,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.seq_lens.np[num_reqs:] = 0
             self.seq_lens.copy_to_gpu()
 
+            cum_num_tokens, _ = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
+            self.query_start_loc.copy_to_gpu()
+
             for kv_cache_group_id, kv_cache_group_spec in enumerate(
                     self.kv_cache_config.kv_cache_groups):
                 common_attn_metadata = CommonAttentionMetadata(
@@ -3060,7 +3052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=num_tokens,
+                num_tokens=num_tokens_after_padding,
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
@@ -3395,56 +3387,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             desc="Capturing CUDA graphs ({}, {})".format(
                 "decode" if uniform_decode else "mixed prefill-decode",
                 cudagraph_runtime_mode.name))
-        enable_dbo = self.parallel_config.enable_dbo
-        # DBO Only supports running Full cudagraphs with uniform
-        # decode lengths
-        if enable_dbo and uniform_decode:
-            for num_tokens in compilation_cases:
-                # If the number of tokens is greater than the microbatching
-                # threshold, don't generate a microbatched cudagraph
-                if (num_tokens
-                        < self.parallel_config.dbo_decode_token_threshold):
-                    continue
-
-                # Warmup
+        # We skip EPLB here since we don't want to record dummy metrics
+        for num_tokens in compilation_cases:
+            # We currently only capture ubatched graphs when its a FULL
+            # cudagraph and for uniform decode batches.
+            capture_ubatched_graph = self.parallel_config.enable_dbo \
+                and cudagraph_runtime_mode == CUDAGraphMode.FULL \
+                and uniform_decode \
+                and check_ubatch_thresholds(
+                    config=self.vllm_config.parallel_config,
+                    num_tokens=num_tokens,
+                    uniform_decode=uniform_decode,
+                )
+
+            # Currently we capture both microbatched and non-microbatched
+            # graphs when capture_ubatched_graph is True, this is because
+            # occasionally we will be forced out of microbatching due to other
+            # DP ranks not microbatching (usually caused by an empty second
+            # microbatch; once we resolve this, we can remove the
+            # non-microbatched graph capture).
+            allow_microbatching_options = [True, False] if \
+                capture_ubatched_graph else [False]
+            for allow_microbatching in allow_microbatching_options:
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
+                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                    # But be careful, warm up with `NONE`is orthogonal to
+                    # if we want to warm up attention or not. This is
+                    # different from the case where `FULL` implies capture
+                    # attention while `PIECEWISE` implies no attention.
                     force_attention = (
                         cudagraph_runtime_mode == CUDAGraphMode.FULL)
                     self._dummy_run(num_tokens,
                                     cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                     force_attention=force_attention,
-                                    uniform_decode=True,
-                                    allow_microbatching=True,
-                                    skip_eplb=True)
-                # Graph Capture
+                                    uniform_decode=uniform_decode,
+                                    allow_microbatching=allow_microbatching,
+                                    skip_eplb=True,
+                                    remove_lora=False)
                 self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=CUDAGraphMode.FULL,
-                                uniform_decode=True,
-                                allow_microbatching=True,
-                                skip_eplb=True)
-        # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
-            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
-                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                # But be careful, warm up with `NONE`is orthogonal to
-                # if we want to warm up attention or not. This is
-                # different from the case where `FULL` implies capture
-                # attention while `PIECEWISE` implies no attention.
-                force_attention = (
-                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                force_attention=force_attention,
+                                cudagraph_runtime_mode=cudagraph_runtime_mode,
                                 uniform_decode=uniform_decode,
+                                allow_microbatching=allow_microbatching,
                                 skip_eplb=True,
                                 remove_lora=False)
-            self._dummy_run(num_tokens,
-                            cudagraph_runtime_mode=cudagraph_runtime_mode,
-                            uniform_decode=uniform_decode,
-                            skip_eplb=True,
-                            remove_lora=False)
         self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
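For orientation, a self-contained sketch of the capture schedule the loop above implies (local helper with a made-up threshold, not a vLLM API): every batch size gets a non-microbatched graph, and a microbatched variant is added only when DBO, FULL cudagraphs, uniform decode and the decode-token threshold all hold.

def capture_plan(compilation_cases, enable_dbo, full_cudagraphs,
                 uniform_decode, dbo_decode_token_threshold):
    plan = []
    for num_tokens in compilation_cases:
        capture_ubatched = (enable_dbo and full_cudagraphs and uniform_decode
                            and num_tokens >= dbo_decode_token_threshold)
        for allow_microbatching in ([True, False] if capture_ubatched
                                    else [False]):
            plan.append((num_tokens, allow_microbatching))
    return plan

# With a hypothetical threshold of 32 decode tokens:
assert capture_plan([64, 8], True, True, True, 32) == [
    (64, True), (64, False), (8, False)
]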
@@ -3500,24 +3487,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         attn_groups: list[AttentionGroup] = []
         for (attn_backend,
              kv_cache_spec), layer_names in attn_backends_map.items():
-            attn_metadata_builders = []
-            attn_metadata_builders.append(attn_backend.get_builder_cls()(
-                kv_cache_spec,
+            attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_backend,
                 layer_names,
+                kv_cache_spec,
                 self.vllm_config,
                 self.device,
-            ))
-            if self.parallel_config.enable_dbo:
-                attn_metadata_builders.append(
-                    attn_backend.get_builder_cls()(
-                        kv_cache_spec,
-                        layer_names,
-                        self.vllm_config,
-                        self.device,
-                    ))
-            attn_group = AttentionGroup(attn_backend,
-                                        attn_metadata_builders,
-                                        layer_names, kv_cache_spec)
+                num_metadata_builders=1
+                if not self.parallel_config.enable_dbo else 2,
+            )
             attn_groups.append(attn_group)
         return attn_groups
 
@@ -1,25 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import dataclasses
 import threading
+from dataclasses import dataclass
 from typing import Any, Callable, Optional
 
 import torch
 
+import vllm.envs as envs
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.distributed import get_ep_group
 from vllm.forward_context import (create_forward_context, get_forward_context,
                                   override_forward_context)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import has_deep_gemm
 from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 
 logger = init_logger(__name__)
 
 
-@dataclasses.dataclass
+@dataclass
 class UbatchMetadata:
     context: UBatchContext
     input_ids: torch.Tensor
@@ -29,13 +32,55 @@ class UbatchMetadata:
     num_tokens: int
 
 
-@dataclasses.dataclass
+@dataclass
 class CUDAGraphMetaData:
     cudagraph: torch.cuda.CUDAGraph
     ubatch_metadata: UbatchMetadata
     outputs: Optional[Any] = None
 
 
+class SMControlContextManager:
+
+    def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
+                 set_compute_sms: Callable[[int], None]):
+        """
+        Context manager for controlling SM (Streaming Multiprocessor)
+        allocation. Upon entering the context, it sets the number of SMs
+        allocated for communication and computation to comm_sms and
+        total_sms - comm_sms respectively. Upon exiting, it restores the
+        allocation to use all available SMs (i.e. total_sms).
+
+        Args:
+            comm_sms (int): The number of SMs to allocate for communication.
+                (The remainder will be used for computation.)
+            set_comm_sms (Callable[[int], None]):
+                A function that sets the number of SMs for communication.
+            set_compute_sms (Callable[[int], None]):
+                A function that sets the number of SMs for computation.
+        """
+
+        assert current_platform.is_cuda(), \
+            "SM control is currently only supported on CUDA"
+
+        props = torch.cuda.get_device_properties(torch.cuda.current_device())
+        total_sms = props.multi_processor_count
+
+        assert comm_sms < total_sms
+        self.total_sms = total_sms
+        self.compute_sms = total_sms - comm_sms
+        self.comm_sms = comm_sms
+        self.set_comm_sms = set_comm_sms
+        self.set_compute_sms = set_compute_sms
+
+    def __enter__(self):
+        self.set_comm_sms(self.comm_sms)
+        self.set_compute_sms(self.compute_sms)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.set_comm_sms(self.total_sms)
+        self.set_compute_sms(self.total_sms)
+
+
 class UBatchWrapper:
 
     def __init__(self, runnable: Callable, vllm_config: VllmConfig,
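A small, CPU-only sketch of the enter/exit behaviour described in the docstring above, with the CUDA device query replaced by a hard-coded SM count so it runs anywhere; the class and setter names here are stand-ins, not the ones added by this commit.

class SMBudget:
    """Split a fixed SM budget between comm and compute for a scoped block."""

    def __init__(self, total_sms, comm_sms, set_comm_sms, set_compute_sms):
        assert comm_sms < total_sms
        self.total_sms = total_sms
        self.comm_sms = comm_sms
        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms

    def __enter__(self):
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.total_sms - self.comm_sms)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the full budget to both sides on exit.
        self.set_comm_sms(self.total_sms)
        self.set_compute_sms(self.total_sms)

calls = []
with SMBudget(132, 20, lambda n: calls.append(("comm", n)),
              lambda n: calls.append(("compute", n))):
    pass
assert calls == [("comm", 20), ("compute", 112),
                 ("comm", 132), ("compute", 132)]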
@@ -56,6 +101,35 @@ class UBatchWrapper:
                 runnable, vllm_config, runtime_mode=runtime_mode)
             self.graph_pool = current_platform.get_global_graph_pool()
 
+        self.sm_control = self._create_sm_control_context(vllm_config)
+
+    @staticmethod
+    def _create_sm_control_context(vllm_config: VllmConfig):
+        comm_sms = envs.VLLM_DBO_COMM_SMS
+
+        set_comm_sms = lambda sms: None
+        if vllm_config.parallel_config.enable_expert_parallel:
+            # Currently only DeepEP highthroughput supports SM control so this
+            # only affects that case.
+            all2all_manager = get_ep_group(
+            ).device_communicator.all2all_manager
+
+            if all2all_manager.max_sms_used() is not None:
+                comm_sms = min(comm_sms, all2all_manager.max_sms_used())
+
+            if comm_sms > 0:
+                set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
+
+        # TODO(lucas): support other kernels besides DeepGEMM
+        set_compute_sms = lambda sms: None
+        if has_deep_gemm() and comm_sms > 0:
+            import deep_gemm as dg
+            set_compute_sms = lambda sms: dg.set_num_sms(sms)
+
+        return SMControlContextManager(comm_sms=comm_sms,
+                                       set_comm_sms=set_comm_sms,
+                                       set_compute_sms=set_compute_sms)
+
     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
         if hasattr(self.runnable, key):
@@ -282,8 +356,8 @@ class UBatchWrapper:
                 dp_metadata=dp_metadata,
                 batch_descriptor=batch_descriptor,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE)
-            return self._capture_ubatches(ubatch_metadata, self.model)
+            with self.sm_control:
+                return self._capture_ubatches(ubatch_metadata, self.model)
         elif num_tokens in self.cudagraphs:
             cudagraph_metadata = self.cudagraphs[num_tokens]
             cudagraph_metadata.cudagraph.replay()
@@ -300,4 +374,5 @@ class UBatchWrapper:
             dp_metadata=dp_metadata,
             batch_descriptor=batch_descriptor,
             cudagraph_runtime_mode=CUDAGraphMode.NONE)
-        return self._run_ubatches(ubatch_metadata, self.model)
+        with self.sm_control:
+            return self._run_ubatches(ubatch_metadata, self.model)
@@ -3,9 +3,10 @@
 
 from typing import Optional
 
+import numpy as np
 import torch
 
-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.forward_context import DPMetadata
 from vllm.logger import init_logger
 from vllm.utils import round_up
@@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens(
                                       dp_size, dp_rank)
 
 
+def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
+                            uniform_decode: bool) -> bool:
+    if not config.enable_dbo:
+        return False
+    if uniform_decode:
+        return num_tokens >= config.dbo_decode_token_threshold
+    else:
+        return num_tokens >= config.dbo_prefill_token_threshold
+
+
 def get_dp_padding_ubatch(
         num_tokens_unpadded: int, num_tokens_padded: int,
         should_attempt_ubatching: bool,
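A quick, self-contained restatement of the threshold rule above with made-up threshold values (the real ones come from ParallelConfig): decode batches compare against the decode threshold, everything else against the prefill threshold, and DBO being disabled short-circuits to False.

def should_try_dbo(num_tokens, uniform_decode, enable_dbo=True,
                   decode_threshold=32, prefill_threshold=512):
    if not enable_dbo:
        return False
    threshold = decode_threshold if uniform_decode else prefill_threshold
    return num_tokens >= threshold

assert should_try_dbo(256, uniform_decode=True)        # 256 >= 32
assert not should_try_dbo(256, uniform_decode=False)   # 256 < 512
assert not should_try_dbo(256, uniform_decode=True, enable_dbo=False)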
@@ -95,9 +106,37 @@ def get_dp_padding_ubatch(
                                             dtype=torch.int32)
     return should_ubatch, num_tokens_after_padding
 
 
+def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
+        -> UBatchSlices:
+    # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
+    # in cu_num_tokens directly (i.e. query_start_loc)
+    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
+    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
+
+    first_ubatch_token_slice = slice(0, split_point)
+    second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
+
+    # Determine request slices using exclusive stop semantics
+    # First ubatch includes requests whose tokens overlap [0, split_point)
+    first_ubatch_req_stop = int(
+        np.searchsorted(cu_num_tokens, split_point, side="left"))
+    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
+
+    # Second ubatch starts at the request that contains the split_point
+    # or the request starting exactly at split_point (if on boundary)
+    second_ubatch_req_start = int(
+        np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
+    second_ubatch_req_slice = slice(second_ubatch_req_start,
+                                    len(cu_num_tokens) - 1)
+
+    return [
+        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
+        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
+    ]
+
+
 def ubatch_split(
-    max_num_scheduled_tokens: int,
+    num_scheduled_tokens_per_request: np.ndarray,
     num_tokens_unpadded: int,
     num_tokens_padded: int,
     vllm_config: VllmConfig,
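To see how the request slices above fall out, here is a small NumPy-only walk-through with hypothetical per-request token counts; it re-derives the searchsorted logic locally rather than importing the new helper.

import numpy as np

num_scheduled_tokens = np.array([4, 6], dtype=np.int32)  # two prefills
split_point = 7                                          # split inside req 1

cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
# cu_num_tokens == [0, 4, 10]

first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))
second_req_start = int(
    np.searchsorted(cu_num_tokens, split_point, side="right") - 1)

# First ubatch: tokens [0, 7) drawn from both requests.
assert (first_req_stop, split_point) == (2, 7)
# Second ubatch: tokens [7, 10), all belonging to request 1.
assert (second_req_start, int(cu_num_tokens[-1])) == (1, 10)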
@@ -122,17 +161,20 @@ def ubatch_split(
         return (None, None)
 
     # Check preconditions for microbatching
-    should_attempt_ubatching = \
-        parallel_config.enable_dbo and \
-        num_tokens_unpadded >= \
-        parallel_config.dbo_decode_token_threshold \
-        and max_num_scheduled_tokens == 1
+    should_attempt_ubatching = check_ubatch_thresholds(
+        parallel_config,
+        num_tokens_unpadded,
+        vllm_config,
+    )
 
     # Don't microbatch unless every other DP worker is also microbatching
-    num_tokens_after_padding = None
-    (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
-        num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
-        vllm_config)
+    should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
+        num_tokens_unpadded,
+        num_tokens_padded,
+        should_attempt_ubatching,
+        vllm_config,
+    )
 
     if not should_ubatch:
         return (None, None)
 
@@ -141,15 +183,9 @@ def ubatch_split(
     # to the second ubatch in pad_out_ubatch_slice after attention
     # metadata creation
     assert num_tokens_after_padding is not None
-    total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
-    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
-    padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
-                                       num_tokens_unpadded)
+    token_split_point = int(num_tokens_after_padding[0].item())
 
-    # Note there's an assumption here that there's 1 token per request
-    ubatch_slices = [
-        UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
-        UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
-    ]
+    ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
+                                         token_split_point)
 
     return (ubatch_slices, num_tokens_after_padding)
@@ -10,6 +10,14 @@ class UBatchSlice:
     request_slice: slice
     token_slice: slice
 
+    def is_empty(self) -> bool:
+        return self.request_slice.start == self.request_slice.stop \
+            or self.token_slice.start == self.token_slice.stop
+
+    @property
+    def num_tokens(self) -> int:
+        return self.token_slice.stop - self.token_slice.start
+
 
 UBatchSlices: TypeAlias = list[UBatchSlice]
 
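The two helpers added above are small enough to exercise directly; a stand-alone equivalent (a local dataclass rather than the real UBatchSlice) behaves like this:

from dataclasses import dataclass

@dataclass
class _Slice:  # local stand-in for UBatchSlice, illustration only
    request_slice: slice
    token_slice: slice

    def is_empty(self) -> bool:
        return (self.request_slice.start == self.request_slice.stop
                or self.token_slice.start == self.token_slice.stop)

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start

assert _Slice(slice(0, 0), slice(0, 0)).is_empty()
assert not _Slice(slice(0, 2), slice(0, 7)).is_empty()
assert _Slice(slice(0, 2), slice(0, 7)).num_tokens == 7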
@@ -51,8 +51,8 @@ class UBatchContext:
         self.cpu_wait_event.wait()
         self.cpu_wait_event.clear()
         self._restore_context()
-        # Assume we start on the compute stream
-        assert current_stream() == self.compute_stream
+        # Assume we want to start on the compute stream
+        self.update_stream(self.compute_stream)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -62,17 +62,15 @@ class UBatchContext:
         self.maybe_run_recv_hook()
         self.cpu_signal_event.set()
         self.cpu_wait_event.clear()
-        self.current_stream = self.compute_stream
-        torch.cuda.set_stream(self.current_stream)
         return False
 
     def _restore_context(self):
         forward_context._forward_context = self.forward_context
-        torch.cuda.set_stream(self.current_stream)
 
     def update_stream(self, stream):
         self.current_stream = stream
-        torch.cuda.set_stream(self.current_stream)
+        if current_stream() != self.current_stream:
+            torch.cuda.set_stream(self.current_stream)
 
     def _signal_comm_done(self):
         self.gpu_comm_done_event.record(self.comm_stream)
@@ -99,9 +97,20 @@ class UBatchContext:
         self.cpu_wait_event.clear()
         self._restore_context()
 
+    def switch_to_comm(self):
+        self.update_stream(self.comm_stream)
+
+    def switch_to_compute(self):
+        self.update_stream(self.compute_stream)
+
     def switch_to_comm_sync(self):
         self._signal_compute_done()
         self.update_stream(self.comm_stream)
+        self._wait_compute_done()
+
+    def switch_to_compute_sync(self):
+        self._signal_comm_done()
+        self.update_stream(self.compute_stream)
         self._wait_comm_done()
 
     def maybe_run_recv_hook(self):
@@ -112,8 +121,7 @@ class UBatchContext:
     def yield_(self):
         self.current_stream = current_stream()
         self._cpu_yield()
-        if self.current_stream != current_stream():
-            self.update_stream(self.current_stream)
+        self.update_stream(self.current_stream)
 
     def yield_and_switch_from_compute_to_comm(self):
         assert current_stream() == self.compute_stream
@@ -153,15 +161,20 @@ def _register_ubatch_function(func):
     return wrapper
 
 
+dbo_maybe_run_recv_hook = _register_ubatch_function(
+    UBatchContext.maybe_run_recv_hook)
+dbo_yield = _register_ubatch_function(UBatchContext.yield_)
 dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
     UBatchContext.yield_and_switch_from_compute_to_comm)
 dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
     UBatchContext.yield_and_switch_from_comm_to_compute)
-dbo_yield = _register_ubatch_function(UBatchContext.yield_)
-dbo_maybe_run_recv_hook = _register_ubatch_function(
-    UBatchContext.maybe_run_recv_hook)
+dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
+dbo_switch_to_compute = _register_ubatch_function(
+    UBatchContext.switch_to_compute)
 dbo_switch_to_comm_sync = _register_ubatch_function(
     UBatchContext.switch_to_comm_sync)
+dbo_switch_to_compute_sync = _register_ubatch_function(
+    UBatchContext.switch_to_compute_sync)
 
 
 def dbo_register_recv_hook(recv_hook):
@@ -130,15 +130,32 @@ class MultiModalBudget:
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
     metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
    kv_cache_spec: KVCacheSpec
 
+    @staticmethod
+    def create_with_metadata_builders(
+        backend: type[AttentionBackend],
+        layer_names: list[str],
+        kv_cache_spec: KVCacheSpec,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        num_metadata_builders: int = 1,
+    ) -> 'AttentionGroup':
+        metadata_builders = [
+            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
+                                      device)
+            for _ in range(num_metadata_builders)
+        ]
+        return AttentionGroup(backend, metadata_builders, layer_names,
+                              kv_cache_spec)
+
     def get_metadata_builder(self,
-                             ubatch_id: Optional[int] = None
-                             ) -> AttentionMetadataBuilder:
-        if ubatch_id is None:
-            return self.metadata_builders[0]
+                             ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
         return self.metadata_builders[ubatch_id]
 
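Finally, a plain-Python sketch of why AttentionGroup now carries one metadata builder per ubatch: each microbatch indexes its own builder, so any per-builder persistent buffers are never shared between ubatches. The classes here are throwaway stand-ins, not vLLM types.

class _Builder:  # stand-in for an AttentionMetadataBuilder
    def __init__(self, ubatch_id):
        self.ubatch_id = ubatch_id
        self.persistent_buffer = []   # private to this ubatch

def make_builders(enable_dbo: bool):
    # Two builders when DBO (two ubatches) is enabled, one otherwise.
    return [_Builder(i) for i in range(2 if enable_dbo else 1)]

builders = make_builders(enable_dbo=True)

def get_metadata_builder(ubatch_id: int = 0) -> _Builder:
    assert len(builders) > ubatch_id
    return builders[ubatch_id]

assert get_metadata_builder(0) is not get_metadata_builder(1)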