Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 09:15:55 +08:00)

[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>

parent 08369289af
commit 567939953b
@@ -87,6 +87,11 @@ def parse_args():
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--enable-dbo",
+        action="store_true",
+        help=("Enable microbatched execution"),
+    )
     parser.add_argument(
         "--compilation-config",
         type=int,
@@ -113,6 +118,7 @@ def main(
     max_model_len,
     compilation_config,
     gpu_memory_utilization,
+    enable_dbo,
     quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -167,6 +173,7 @@ def main(
         max_num_seqs=max_num_seqs,
         max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
+        enable_dbo=enable_dbo,
         quantization=quantization,
         compilation_config=compilation_config,
     )
@@ -227,6 +234,7 @@ if __name__ == "__main__":
             args.max_model_len,
             args.compilation_config,
             args.gpu_memory_utilization,
+            args.enable_dbo,
             args.quantization,
         ),
     )
@@ -6,7 +6,7 @@ import torch

 from tests.v1.attention.test_attention_backends import BATCH_SPECS
 from tests.v1.attention.utils import create_common_attn_metadata
-from vllm.v1.attention.backends.utils import (UbatchSlice,
+from vllm.v1.attention.backends.utils import (UBatchSlice,
                                                _make_metadata_with_slice,
                                                slice_query_start_locs,
                                                split_attn_metadata)
@@ -106,7 +106,7 @@ def mixed_small_metadata():
 def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
     """Test slicing decode batch metadata"""
     # Split first request only
-    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
+    ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))

     result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)

@@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):

 def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
     """Test slicing mixed batch metadata"""
-    ubatch_slice = UbatchSlice(slice(1, 3),
+    ubatch_slice = UBatchSlice(slice(1, 3),
                                slice(1, 7))  # Requests 1-3, tokens 1-7

     result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
@@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     num_tokens = large_decode_metadata.num_reqs
     mid_point = num_tokens // 2
     ubatch_slices = [
-        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
-        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
+        UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
+        UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
                                                         num_tokens)),
     ]

@@ -365,7 +365,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]

     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -489,7 +491,9 @@ def test_propose_tree(spec_token_tree):
     # Mock runner for attention metadata building.
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]

     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
@@ -2848,6 +2848,14 @@ class VllmConfig:
                 "when cudagraph_mode piecewise cudagraphs is used, "\
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

+        if self.parallel_config.enable_dbo:
+            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
+            assert a2a_backend == "deepep_low_latency", \
+                "Microbatching currently only supports the deepep_low_latency "\
+                f"all2all backend. {a2a_backend} is not supported. To fix set "\
+                "the VLLM_ALL2ALL_BACKEND environment variable to "\
+                "deepep_low_latency and install the DeepEP kerenls."
+
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

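Editor's note: the check above means DBO only works with the deepep_low_latency all2all backend. A minimal, hedged sketch of enabling it offline (the environment variable name and the enable_dbo flag come from this commit; the model name is a placeholder and DeepEP must be installed):

```python
# Sketch only: enabling Dual-Batch Overlap offline.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"  # required by the assert above

from vllm import LLM

llm = LLM(
    model="some/moe-model",   # placeholder, not from the diff
    enable_dbo=True,          # new ParallelConfig flag added in this PR
)
```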
@@ -137,6 +137,14 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""

+    enable_dbo: bool = False
+    """Enable microbatching for the model executor."""
+
+    dbo_decode_token_threshold: int = 32
+    """The threshold for microbatching. If the number of tokens in the
+    request is greater than this threshold, microbatching will be used.
+    Otherwise, the request will be processed in a single batch."""
+
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

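Editor's note: the docstring above describes the intended decision rule; the actual call site lives in the runner/ubatch-splitting code, not in this hunk, so treat the comparison below as an assumption:

```python
# Illustrative only: the rule the dbo_decode_token_threshold docstring describes.
def should_microbatch(enable_dbo: bool, num_decode_tokens: int,
                      dbo_decode_token_threshold: int = 32) -> bool:
    return enable_dbo and num_decode_tokens > dbo_decode_token_threshold

print(should_microbatch(True, 64))  # True
print(should_microbatch(True, 8))   # False
```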
@@ -251,9 +251,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
-        # It is dangerous to set num sms outside this function. num_sms is not
-        # a part of the hash-key that identifies this object. If we are in a
-        # situation where we make objects with different num_sms, the hash key
-        # in get_or_create must be updated.
-        handle.set_num_sms(self.num_sms)
         return handle
@@ -327,6 +327,9 @@ class EngineArgs:
     data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    enable_dbo: bool = ParallelConfig.enable_dbo
+    dbo_decode_token_threshold: int = \
+        ParallelConfig.dbo_decode_token_threshold
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
@@ -695,6 +698,11 @@ class EngineArgs:
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument("--enable-dbo",
+                                    **parallel_kwargs["enable_dbo"])
+        parallel_group.add_argument(
+            "--dbo-decode-token-threshold",
+            **parallel_kwargs["dbo_decode_token_threshold"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1339,6 +1347,8 @@ class EngineArgs:
             data_parallel_backend=self.data_parallel_backend,
             data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
+            enable_dbo=self.enable_dbo,
+            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,
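Editor's note: the two new engine arguments map to the CLI flags added above (--enable-dbo and --dbo-decode-token-threshold). A hedged programmatic sketch, with a placeholder model name:

```python
# Programmatic equivalent of `--enable-dbo --dbo-decode-token-threshold 64`.
# Assumes a vLLM build containing this commit.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="some/moe-model",         # placeholder
    enable_dbo=True,
    dbo_decode_token_threshold=64,
)
```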
@@ -14,6 +14,7 @@ import vllm.envs as envs
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ class DPMetadata:
         dist.all_reduce(num_tokens_tensor, group=group)
         return num_tokens_tensor.cpu()

+    @staticmethod
+    def should_ubatch_across_dp(
+            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+            padded_num_tokens_per_ubatch: int, dp_size: int,
+            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+        """
+        1. Decides if each DP rank is going to microbatch. Either all ranks
+        run with microbatching or none of them do. If this function decides
+        not to run with microbatching. It will "abort" meaning that no padding
+        information will be returned to the caller. It will return (False, None)
+
+        2. Determines the total number of tokens that each rank will run.
+        All ranks will be padded out so that the run with the same number
+        of tokens
+
+        Returns: tuple[
+            should_ubatch: Are all DP ranks going to microbatch
+            num_tokens_after_padding: A tensor containing the total number of
+            tokens per-microbatch for each DP rank including padding. Will be
+            None if should_ubatch if False
+        ]
+        """
+
+        device = current_platform.device_type
+        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+        result: bool = bool(torch.all(tensor[2] == 1).item())
+        if not result:
+            return result, None
+
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
+                         padded_max_num_tokens)
+            return False, None
+        return result, padded_num_tokens_tensor.cpu()
+
     @staticmethod
     def make(
         parallel_config: ParallelConfig,
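Editor's note: to make the coordination above concrete, here is a single-process illustration of the same tensor layout and vote. The all-reduce is simulated by filling every rank's column directly, and the is_second_ubatch_empty abort is omitted since that helper lives in ubatch_utils and is not shown in this hunk:

```python
# Single-process sketch of the decision in should_ubatch_across_dp:
# row 0 = unpadded tokens per ubatch, row 1 = padded tokens, row 2 = per-rank vote.
import torch

def simulate(votes, orig_tokens, padded_tokens):
    dp_size = len(votes)
    tensor = torch.zeros(3, dp_size, dtype=torch.int32)
    tensor[0] = torch.tensor(orig_tokens, dtype=torch.int32)
    tensor[1] = torch.tensor(padded_tokens, dtype=torch.int32)
    tensor[2] = torch.tensor([1 if v else 0 for v in votes], dtype=torch.int32)
    # every rank must vote yes, otherwise abort with no padding information
    if not bool(torch.all(tensor[2] == 1).item()):
        return False, None
    return True, tensor[1].clone()   # every rank pads to these per-ubatch counts

print(simulate([True, True], [16, 20], [32, 32]))   # (True, tensor([32, 32], dtype=torch.int32))
print(simulate([True, False], [16, 20], [32, 32]))  # (False, None)
```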
@@ -119,14 +167,15 @@ class DPMetadata:

         # If num_tokens_across_dp is None, it will be computed by all_reduce
         # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None
-                or num_tokens_across_dp[dp_rank] == batchsize)
+        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
+                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
         if num_tokens_across_dp is None:
             num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
+                          num_tokens_across_dp)

     @contextmanager
     def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
@@ -179,9 +228,12 @@ class ForwardContext:
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
-    set dynamically for each forward pass
+    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
+    for each microbatch.
+    Set dynamically for each forward pass
     """
-    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
+                         list[dict[str, "AttentionMetadata"]]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
@@ -191,6 +243,8 @@ class ForwardContext:
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
     batch_descriptor: Optional[BatchDescriptor] = None

+    ubatch_slices: Optional[UBatchSlices] = None
+
     def __post_init__(self):
         assert self.cudagraph_runtime_mode in [
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
     return _forward_context


+def create_forward_context(
+        attn_metadata: Any,
+        vllm_config: VllmConfig,
+        virtual_engine: int = 0,
+        dp_metadata: Optional[DPMetadata] = None,
+        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
+    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
+                          static_forward_context,
+                          virtual_engine=virtual_engine,
+                          attn_metadata=attn_metadata,
+                          dp_metadata=dp_metadata,
+                          cudagraph_runtime_mode=cudagraph_runtime_mode,
+                          batch_descriptor=batch_descriptor,
+                          ubatch_slices=ubatch_slices)
+
+
+@contextmanager
+def override_forward_context(forward_context: Optional[ForwardContext]):
+    """A context manager that overrides the current forward context.
+    This is used to override the forward context for a specific
+    forward pass.
+    """
+    global _forward_context
+    prev_context = _forward_context
+    _forward_context = forward_context
+    try:
+        yield
+    finally:
+        _forward_context = prev_context
+
+
 @contextmanager
 def set_forward_context(
     attn_metadata: Any,
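Editor's note: the new override_forward_context follows a plain save/restore pattern. A standalone sketch of that pattern, with a dict standing in for the real ForwardContext global:

```python
# Minimal save/restore context manager, mirroring override_forward_context.
from contextlib import contextmanager
from typing import Optional

_current_ctx: Optional[dict] = None

@contextmanager
def override_ctx(ctx: Optional[dict]):
    global _current_ctx
    prev = _current_ctx
    _current_ctx = ctx
    try:
        yield
    finally:
        _current_ctx = prev   # restored even if the body raises

with override_ctx({"ubatch_id": 0}):
    print(_current_ctx)   # {'ubatch_id': 0}
print(_current_ctx)       # None
```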
@@ -216,7 +303,8 @@ def set_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -225,6 +313,7 @@ def set_forward_context(
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
+
     dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
@@ -232,20 +321,14 @@ def set_forward_context(
             attn_metadata, num_tokens or 0,
             num_tokens_across_dp)

-    global _forward_context
-    prev_context = _forward_context
-    _forward_context = ForwardContext(
-        no_compile_layers=vllm_config.compilation_config.
-        static_forward_context,
-        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata,
-        cudagraph_runtime_mode=cudagraph_runtime_mode,
-        batch_descriptor=batch_descriptor,
-    )
+    forward_context = create_forward_context(attn_metadata, vllm_config,
+                                             virtual_engine, dp_metadata,
+                                             cudagraph_runtime_mode,
+                                             batch_descriptor, ubatch_slices)

     try:
-        yield
+        with override_forward_context(forward_context):
+            yield
     finally:
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
@@ -282,5 +365,3 @@ def set_forward_context(
                 logger.info(("Batchsize forward time stats "
                              "(batchsize, count, median_time(ms)): %s"),
                             forward_stats)
-
-        _forward_context = prev_context
@@ -191,7 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> Callable:
+    ) -> tuple[Callable, mk.ReceiverType]:

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -217,13 +217,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             a1q_scale = None
             a1_post_scale = a1_scale

-        return self._do_dispatch(tokens=a1q,
-                                 token_scales=a1q_scale,
-                                 rank_topk_ids=topk_ids,
-                                 rank_topk_weights=topk_weights,
-                                 num_experts=num_experts,
-                                 a1_scale=a1_post_scale,
-                                 quant_config=quant_config)
+        return (lambda *args: None,
+                self._do_dispatch(tokens=a1q,
+                                  token_scales=a1q_scale,
+                                  rank_topk_ids=topk_ids,
+                                  rank_topk_weights=topk_weights,
+                                  num_experts=num_experts,
+                                  a1_scale=a1_post_scale,
+                                  quant_config=quant_config))

     def prepare(
         self,
@@ -237,10 +238,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
-                                      topk_ids, num_experts, expert_map,
-                                      apply_router_weight_on_input,
-                                      quant_config)
+        (_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
+                                           topk_weights, topk_ids, num_experts,
+                                           expert_map,
+                                           apply_router_weight_on_input,
+                                           quant_config)
         return receiver()

     def finalize(
@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input, normalize_batched_scales_shape)
+from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
+                                      dbo_maybe_run_recv_hook,
+                                      dbo_register_recv_hook, dbo_yield)

 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -55,7 +58,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         # The dispatch function returns a handle that the combine function
         # requires. We store the handle here so it is available to the
         # combine function.
-        self.handle = None
+        self.handles: list[Optional[tuple]] = [None, None]
         self.num_dispatchers_ = num_dispatchers

     def num_dispatchers(self) -> int:
@@ -123,13 +126,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> mk.ReceiverType:
+    ) -> tuple[Callable, mk.ReceiverType]:

         hidden_size = a1.size(1)
         assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
             (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
              f"{self.SUPPORTED_HIDDEN_SIZES}")

+        a2a_idx = dbo_current_ubatch_id()
+
         if self.use_fp8_dispatch:
             assert hidden_size % 128 == 0, \
                 "DeepEP kernels quantize the inputs in blocks of shape 128"
@@ -148,7 +153,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             a1 = a1 * topk_weights.to(a1.dtype)

         # Dispatch
-        expert_x, expert_num_tokens, self.handle, event, hook = \
+        expert_x, expert_num_tokens, handle, _, hook= \
             self.buffer.low_latency_dispatch(a1,
                                              topk_ids,
                                              self.max_tokens_per_rank,
@@ -156,21 +161,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                              use_fp8=self.use_fp8_dispatch,
                                              async_finish=False,
                                              return_recv_hook=True)
+        self.handles[a2a_idx] = handle

-        return lambda: self._receiver(hook, expert_x, expert_num_tokens,
-                                      a1_scale, a1.dtype, quant_config)
+        return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
+                                             a1_scale, a1.dtype, quant_config))

     def _receiver(
         self,
-        hook: Callable,
         expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         expert_num_tokens: torch.Tensor,
         a1_scale,
         a1_dtype,
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        hook()
-
         expert_x, expert_x_scale = self._do_quant(
             expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)
@@ -192,10 +195,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
-                                      topk_ids, num_experts, expert_map,
-                                      apply_router_weight_on_input,
-                                      quant_config)
+        hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
+                                            topk_weights, topk_ids,
+                                            num_experts, expert_map,
+                                            apply_router_weight_on_input,
+                                            quant_config)
+        hook()
         return receiver()

     def finalize(
@@ -210,7 +215,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
-        assert self.handle is not None
+
+        a2a_idx = dbo_current_ubatch_id()
+        do_recv_hook = dbo_enabled()
+        handle = self.handles[a2a_idx]
+        assert handle is not None

         combine_topk_weights = topk_weights
         if apply_router_weight_on_input:
@@ -218,12 +227,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             combine_topk_weights = torch.ones_like(topk_weights)

         # TODO (varun) : Enable zero copy mode
-        _, event, hook = self.buffer.low_latency_combine(
+        dbo_maybe_run_recv_hook()
+        _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
             topk_ids,
             combine_topk_weights,
-            self.handle,
+            handle,
             async_finish=False,
             zero_copy=False,
-            return_recv_hook=False,
+            return_recv_hook=do_recv_hook,
             out=output)
+        if recv_hook is not None:
+            dbo_register_recv_hook(recv_hook)
+        dbo_yield()
@@ -38,6 +38,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
+from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -992,16 +993,28 @@ class FusedMoE(CustomOp):
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or self.moe_config.use_flashinfer_cutlass_kernels):
-            self.batched_hidden_states = torch.zeros(
-                (moe.max_num_tokens, self.hidden_size),
-                dtype=moe.in_dtype,
-                device=torch.cuda.current_device())
+            if vllm_config.parallel_config.enable_dbo:
+                self.batched_hidden_states = torch.zeros(
+                    (2, moe.max_num_tokens, self.hidden_size),
+                    dtype=moe.in_dtype,
+                    device=torch.cuda.current_device())

-            # Note here we use `num_experts` which is logical expert count
-            self.batched_router_logits = torch.zeros(
-                (moe.max_num_tokens, num_experts),
-                dtype=moe.in_dtype,
-                device=torch.cuda.current_device())
+                # Note here we use `num_experts` which is logical expert count
+                self.batched_router_logits = torch.zeros(
+                    (2, moe.max_num_tokens, num_experts),
+                    dtype=moe.in_dtype,
+                    device=torch.cuda.current_device())
+            else:
+                self.batched_hidden_states = torch.zeros(
+                    (moe.max_num_tokens, self.hidden_size),
+                    dtype=moe.in_dtype,
+                    device=torch.cuda.current_device())
+
+                # Note here we use `num_experts` which is logical expert count
+                self.batched_router_logits = torch.zeros(
+                    (moe.max_num_tokens, num_experts),
+                    dtype=moe.in_dtype,
+                    device=torch.cuda.current_device())

     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
@@ -1708,14 +1721,29 @@ class FusedMoE(CustomOp):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]

-            assert (self.batched_hidden_states.size(0)  # type: ignore
+            assert self.batched_hidden_states is not None
+            assert self.batched_router_logits is not None
+            # This is only true when DBO has been enabled in the config.
+            # Both tensors will have an outer dimension for the ubatch id
+            if self.batched_hidden_states.dim() == 3:
+                assert self.batched_router_logits.dim() == 3
+                batch_buffer_idx = dbo_current_ubatch_id()
+                batched_hidden_states = self.batched_hidden_states[
+                    batch_buffer_idx, :]
+                batched_router_logits = self.batched_router_logits[
+                    batch_buffer_idx, :]
+            else:
+                batched_hidden_states = self.batched_hidden_states
+                batched_router_logits = self.batched_router_logits
+
+            assert (batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
-            assert (self.batched_router_logits.size(0)  # type: ignore
+            assert (batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
-            staged_hidden_states = self.batched_hidden_states[:
+            staged_hidden_states = batched_hidden_states[:
                                                               chunk_size, :]  # type: ignore
-            staged_router_logits = self.batched_router_logits[:
+            staged_router_logits = batched_router_logits[:
                                                               chunk_size, :]  # type: ignore
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)

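Editor's note: the staging change above gives each microbatch its own slice of the batched buffers. A small sketch of the indexing, with made-up shapes:

```python
# Sketch: with DBO the staging buffers gain a leading dimension of 2 and each
# microbatch copies into its own slice, selected by the current ubatch id.
import torch

max_num_tokens, hidden_size = 8, 4
batched_hidden_states = torch.zeros(2, max_num_tokens, hidden_size)

def stage(ubatch_id: int, hidden_states: torch.Tensor) -> torch.Tensor:
    chunk_size = hidden_states.shape[0]
    staged = batched_hidden_states[ubatch_id, :chunk_size, :]
    staged.copy_(hidden_states)
    return staged

stage(1, torch.ones(3, hidden_size))
print(batched_hidden_states[1, :3].sum().item())  # 12.0
```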
@@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
     _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
+from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
+                                      dbo_register_recv_hook, dbo_yield)

 #
 # This file defines a set of base classes used to make MoE kernels more modular.
@@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> ReceiverType:
+    ) -> tuple[Callable, ReceiverType]:
         """
         Perform any quantization (and/or) dispatching needed for this kernel
         but do not wait for results from other workers.
@@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
     return None


+class SharedResizableBuffer:
+
+    def __init__(self):
+        self.buffer = None
+
+    def get(self, shape: tuple[int, ...], device: torch.device,
+            dtype: torch.dtype):
+        shape_numel = prod(shape)
+        if self.buffer is None or self.buffer.numel() < shape_numel:
+            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
+        assert self.buffer.device == device, \
+            f"Buffer device mismatch: {self.buffer.device} != {device}"
+        assert self.buffer.dtype == dtype, \
+            f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
+        return self.buffer[:shape_numel].view(*shape)
+
+
 @final
 class FusedMoEModularKernel(torch.nn.Module):
     """
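Editor's note: a quick usage sketch for the SharedResizableBuffer added above. The class body is copied from the hunk (device/dtype asserts omitted) so the snippet stays self-contained; imports are added here:

```python
from math import prod
import torch

class SharedResizableBuffer:

    def __init__(self):
        self.buffer = None

    def get(self, shape, device, dtype):
        shape_numel = prod(shape)
        if self.buffer is None or self.buffer.numel() < shape_numel:
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)

buf = SharedResizableBuffer()
a = buf.get((4, 8), torch.device("cpu"), torch.float32)
b = buf.get((2, 8), torch.device("cpu"), torch.float32)   # reuses the same storage
print(a.data_ptr() == b.data_ptr())  # True
```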
@@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the component
     objects.
     """
+    fused_out_buffer = SharedResizableBuffer()
+    workspace13_buffer = SharedResizableBuffer()
+    workspace2_buffer = SharedResizableBuffer()

     def __init__(
         self,
@@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.empty(prod(workspace13_shape),
+        workspace13 = self.workspace13_buffer.get(workspace13_shape,
                                   device=a1.device,
                                   dtype=workspace_dtype)
-        workspace2 = torch.empty(prod(workspace2_shape),
+        workspace2 = self.workspace2_buffer.get(workspace2_shape,
                                  device=a1.device,
                                  dtype=workspace_dtype)

         assert fused_out is None or fused_out.shape == fused_out_shape, (
             f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
             (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
                 a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
                 expert_tokens_meta)
-            fused_out = torch.empty(fused_out_shape,
+            fused_out = self.fused_out_buffer.get(fused_out_shape,
                                     device=a1q.device,
                                     dtype=a1.dtype)

         def slice_input_tensors(
                 chunk_idx: int
@@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):

         shared_output: torch.Tensor

-        if (not self.prepare_finalize.supports_async()
-                or self.shared_experts is None):
+        if not self.prepare_finalize.supports_async():
+            # We shouldn't be running an a2a kernel that doesn't
+            # support async prepare/finalize
+            assert not dbo_enabled()

             # Run shared experts serially with dispatch.
             if self.shared_experts is not None:
@@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
             )
         else:
             # Overlap shared expert compute with all2all dispatch.
-            receiver = self.prepare_finalize.prepare_async(
+            dbo_maybe_run_recv_hook()
+            hook, receiver = self.prepare_finalize.prepare_async(
                 a1,
                 a1_scale,
                 a2_scale,
@@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.quant_config,
             )

-            assert self.shared_experts is not None
-            shared_output = self.shared_experts(a1)
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            # If DBO is being used, register the hook with the ubatch context
+            # and call it in dbo_maybe_run_recv_hook instead of passing it to
+            # the receiver.
+            dbo_register_recv_hook(hook)
+            dbo_yield()
+            if not dbo_enabled():
+                hook()

             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = receiver()
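Editor's note: prepare_async now returns a (hook, receiver) pair; the hook blocks until the dispatched data has arrived and the receiver post-processes it. A toy sketch of that control flow, with no real communication involved:

```python
# Toy sketch of the (hook, receiver) split: without DBO the caller runs the
# hook inline; with DBO the hook is registered so the other microbatch can
# overlap useful work with the wait.
from typing import Callable, Tuple

def prepare_async_sketch(payload: int) -> Tuple[Callable[[], None], Callable[[], int]]:
    state = {}

    def hook() -> None:
        state["received"] = payload   # stands in for the blocking recv

    def receiver() -> int:
        return state["received"] * 2  # stands in for post-processing

    return hook, receiver

hook, receiver = prepare_async_sketch(21)
hook()              # non-DBO path: wait inline
print(receiver())   # 42
```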
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import pplx_kernels as pplx
 import torch
@@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> mk.ReceiverType:
+    ) -> tuple[Callable, mk.ReceiverType]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K

@@ -214,30 +214,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             do_recv=False,
         )

-        return lambda: self._receiver(
-            expert_num_tokens,
-            expert_x,
-            expert_x_scale,
-            a1q,
-            a1q_scale,
-            topk_ids,
-            bound_m,
-            orig_a_scale_block_shape,
-        )
-
-    def _receiver(
-        self,
-        expert_num_tokens: torch.Tensor,
-        expert_x: torch.Tensor,
-        expert_x_scale: Optional[torch.Tensor],
-        a1q: torch.Tensor,
-        a1q_scale: Optional[torch.Tensor],
-        topk_ids: torch.Tensor,
-        bound_m: Optional[torch.Tensor],
-        orig_a_scale_block_shape: Optional[int],
-    ) -> mk.PrepareResultType:
-
-        self.a2a.dispatch(
+        hook = lambda: self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
             out_expert_x=expert_x,
             out_expert_x_scale=expert_x_scale,
@@ -249,6 +226,21 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             do_recv=True,
         )

+        return (hook, lambda: self._receiver(
+            expert_num_tokens,
+            expert_x,
+            expert_x_scale,
+            orig_a_scale_block_shape,
+        ))
+
+    def _receiver(
+        self,
+        expert_num_tokens: torch.Tensor,
+        expert_x: torch.Tensor,
+        expert_x_scale: Optional[torch.Tensor],
+        orig_a_scale_block_shape: Optional[int],
+    ) -> mk.PrepareResultType:
+
         if expert_x_scale is not None:
             expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
             assert expert_x_scale.ndim == 3
@@ -270,7 +262,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
-        receiver = self.prepare_async(
+        hook, receiver = self.prepare_async(
             a1,
             a1_scale,
             a2_scale,
@@ -281,6 +273,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             apply_router_weight_on_input,
             quant_config,
         )
+        hook()
         return receiver()

     def finalize(
|||||||
@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|||||||
get_kv_connector_cache_layout)
|
get_kv_connector_cache_layout)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||||
@ -81,12 +82,6 @@ class CommonAttentionMetadata:
|
|||||||
encoder_seq_lens: Optional[np.ndarray] = None
|
encoder_seq_lens: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UbatchSlice:
|
|
||||||
request_slice: slice
|
|
||||||
token_slice: slice
|
|
||||||
|
|
||||||
|
|
||||||
def slice_query_start_locs(
|
def slice_query_start_locs(
|
||||||
query_start_loc: torch.Tensor,
|
query_start_loc: torch.Tensor,
|
||||||
request_slice: slice,
|
request_slice: slice,
|
||||||
@ -103,7 +98,7 @@ def slice_query_start_locs(
|
|||||||
|
|
||||||
|
|
||||||
def _make_metadata_with_slice(
|
def _make_metadata_with_slice(
|
||||||
ubatch_slice: UbatchSlice,
|
ubatch_slice: UBatchSlice,
|
||||||
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
|
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
|
||||||
"""
|
"""
|
||||||
This function creates a new CommonAttentionMetadata that corresponds to
|
This function creates a new CommonAttentionMetadata that corresponds to
|
||||||
@ -133,6 +128,11 @@ def _make_metadata_with_slice(
|
|||||||
torch.max(torch.abs(query_start_loc_cpu[1:] -
|
torch.max(torch.abs(query_start_loc_cpu[1:] -
|
||||||
query_start_loc_cpu[:-1])).item())
|
query_start_loc_cpu[:-1])).item())
|
||||||
|
|
||||||
|
# This is to account for the case where we are in a dummy
|
||||||
|
# run and query_start_loc_cpu is full of 0s
|
||||||
|
if max_query_len == 0:
|
||||||
|
max_query_len = attn_metadata.max_query_len
|
||||||
|
|
||||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||||
|
|
||||||
@ -152,12 +152,12 @@ def _make_metadata_with_slice(
|
|||||||
|
|
||||||
|
|
||||||
def split_attn_metadata(
|
def split_attn_metadata(
|
||||||
ubatch_slices: list[UbatchSlice],
|
ubatch_slices: list[UBatchSlice],
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
) -> list[CommonAttentionMetadata]:
|
) -> list[CommonAttentionMetadata]:
|
||||||
"""
|
"""
|
||||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||||
requests for each UbatchSlice in ubatch_slices.
|
requests for each UBatchSlice in ubatch_slices.
|
||||||
|
|
||||||
Note: This function does not modify common_attn_metadata
|
Note: This function does not modify common_attn_metadata
|
||||||
"""
|
"""
|
||||||
|
|||||||
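Editor's note: UbatchSlice moves out of this file and is now imported as UBatchSlice from vllm.v1.worker.ubatch_utils. As a reminder of what it describes, a self-contained sketch that re-declares the two-field dataclass and re-bases query_start_loc for a slice, similar in spirit to slice_query_start_locs above:

```python
import torch
from dataclasses import dataclass

@dataclass
class UBatchSlice:          # same two fields as the moved dataclass
    request_slice: slice
    token_slice: slice

query_start_loc = torch.tensor([0, 2, 5, 9, 12])   # 4 requests, 12 tokens
ub = UBatchSlice(request_slice=slice(0, 2), token_slice=slice(0, 5))

sliced = query_start_loc[ub.request_slice.start:ub.request_slice.stop + 1]
sliced = sliced - sliced[0]   # re-base so the slice starts at token 0
print(sliced)                 # tensor([0, 2, 5])
```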
@@ -27,6 +27,7 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 logger = init_logger(__name__)

@@ -179,9 +180,11 @@ class EagleProposer:
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
-            .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                draft_index=0)
+        ubatch_id = dbo_current_ubatch_id()
+        attn_metadata_builder = \
+            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
+        attn_metadata = attn_metadata_builder.build_for_drafting(
+            common_attn_metadata=common_attn_metadata, draft_index=0)

         # At this moment, we assume all eagle layers belong to the same KV
         # cache group, thus using the same attention metadata.
@@ -355,8 +358,9 @@ class EagleProposer:
         hidden_states: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> list[torch.Tensor]:
+        ubatch_id = dbo_current_ubatch_id()
         tree_attn_metadata_builder = \
-            self.runner.attn_groups[0][0].metadata_builder
+            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)

@@ -64,8 +64,13 @@ class CPUModelRunner(GPUModelRunner):
         if not self.attn_groups[0]:
             return

-        mb = getattr(self.attn_groups[0][0], "metadata_builder", None)
-        if not isinstance(mb, TorchSDPAMetadataBuilderV1):
+        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
+        if isinstance(mb, list):
+            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
+                return
+            mb[0].reorder_batch(self.input_batch, scheduler_output)
+            return
+        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
             # Encoder-only / rerank models do not benefit from reordering,
             # so we safely skip here.
             return
@@ -15,6 +15,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 from tqdm import tqdm
+from typing_extensions import TypeAlias
 
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
@@ -55,11 +56,12 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up, supports_dynamo)
+from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
-    reorder_batch_to_split_decodes_and_prefills)
+    reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -85,9 +87,12 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
+from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -105,6 +110,11 @@ else:
 
 logger = init_logger(__name__)
 
+AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
+# list when ubatching is enabled
+PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
+                                        AttnMetadataDict]
+
 
 # Wrapper for ModelRunnerOutput to support overlapped execution.
 class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
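For reference, a minimal illustration of the two shapes the PerLayerAttnMetadata alias above covers; the layer name and metadata objects are placeholders, not values from the patch:

    # Illustration only: placeholder metadata objects stand in for the
    # backend-specific AttentionMetadata instances built per layer.
    meta_ubatch0 = object()
    meta_ubatch1 = object()

    # Without microbatching: a single layer-name -> metadata mapping.
    attn_metadata_single = {"model.layers.0.self_attn.attn": meta_ubatch0}

    # With DBO enabled: one mapping per ubatch, indexed first by ubatch id
    # (matching the attn_metadata[ubid][layer_name] writes later in the diff).
    attn_metadata_ubatched = [
        {"model.layers.0.self_attn.attn": meta_ubatch0},
        {"model.layers.0.self_attn.attn": meta_ubatch1},
    ]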
@@ -274,6 +284,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        self.comm_stream = torch.cuda.Stream()
 
         # Input Batch
         # NOTE(Chen): Ideally, we should initialize the input batch inside
@@ -872,10 +883,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return encoder_seq_lens
 
     def _prepare_inputs(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
-               np.ndarray, Optional[CommonAttentionMetadata], int]:
+        self, scheduler_output: "SchedulerOutput"
+    ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
+               Optional[SpecDecodeMetadata], np.ndarray,
+               Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
+               Optional[torch.Tensor]]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
@@ -947,6 +959,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.query_start_loc.copy_to_gpu()
         query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
 
+        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
+        num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
+            num_tokens_unpadded)
+        ubatch_slices, num_tokens_after_padding = \
+            ubatch_split(max_num_scheduled_tokens,
+                         num_tokens_unpadded,
+                         num_tokens_padded,
+                         self.vllm_config)
+
         self.seq_lens.np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
@@ -1001,7 +1022,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
             logits_indices)
 
-        attn_metadata: dict[str, Any] = {}
+        attn_metadata: PerLayerAttnMetadata = {}
+        if ubatch_slices is not None:
+            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
 
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1075,7 +1098,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 # Prepare for cascade attention if enabled & beneficial.
                 common_prefix_len = 0
-                builder = attn_group.metadata_builder
+                builder = attn_group.get_metadata_builder()
                 if self.cascade_attn_enabled:
                     common_prefix_len = self._compute_cascade_attn_prefix_len(
                         num_scheduled_tokens,
@@ -1093,13 +1116,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                         num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
                     )
 
-                attn_metadata_i = builder.build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    **extra_attn_metadata_args)
-                for layer_name in attn_group.layer_names:
-                    attn_metadata[layer_name] = attn_metadata_i
+                if ubatch_slices is not None:
+                    common_attn_metadata_list = split_attn_metadata(
+                        ubatch_slices, common_attn_metadata)
+                    for ubid, common_attn_metadata in enumerate(
+                            common_attn_metadata_list):
+                        assert common_attn_metadata.max_query_len == 1
+                        attn_metadata_i = (attn_group.get_metadata_builder(
+                            ubatch_id=ubid).build(
+                                common_prefix_len=common_prefix_len,
+                                common_attn_metadata=common_attn_metadata))
+                        for layer_name in kv_cache_group_spec.layer_names:
+                            assert type(attn_metadata) is list
+                            attn_metadata[ubid][layer_name] = attn_metadata_i
+                else:
+                    assert isinstance(attn_metadata, dict)
+                    attn_metadata_i = builder.build(
+                        common_prefix_len=common_prefix_len,
+                        common_attn_metadata=common_attn_metadata,
+                        **extra_attn_metadata_args)
+                    for layer_name in attn_group.layer_names:
+                        attn_metadata[layer_name] = attn_metadata_i
 
         # Hot-Swap lora model
         if self.lora_config:
@@ -1107,7 +1144,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return (attn_metadata, logits_indices, spec_decode_metadata,
                 num_scheduled_tokens, spec_decode_common_attn_metadata,
-                max_num_scheduled_tokens)
+                max_num_scheduled_tokens, ubatch_slices,
+                num_tokens_after_padding)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1508,7 +1546,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def get_model(self) -> nn.Module:
         # get raw model out of the cudagraph wrapper.
-        if isinstance(self.model, CUDAGraphWrapper):
+        if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
             return self.model.unwrap()
         return self.model
 
@@ -1675,6 +1713,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def get_dp_padding(self,
                        num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+        """
+        Determines the total number of tokens that each rank will run.
+        All ranks will be padded out so that they run with the same number
+        of tokens.
+
+        Returns: tuple[
+            num_pad_tokens: The number of tokens that will be added to the batch
+            num_tokens_after_padding: A tensor containing the total number of
+                tokens for each DP rank including padding.
+        ]
+        """
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
 
@@ -1698,6 +1747,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                 dtype=torch.int32)
         return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
+    def get_local_padding(self, num_tokens_unpadded: int) -> int:
+        num_tokens_padded = num_tokens_unpadded
+
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
+            # Use piecewise CUDA graphs.
+            # Add padding to the batch size.
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(
+                num_tokens_unpadded)
+        else:
+            # Eager mode.
+            # Pad tokens to a multiple of tensor_parallel_size when
+            # collective fusion for SP is enabled.
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            if self.vllm_config.compilation_config.pass_config. \
+                enable_sequence_parallelism and tp_size > 1:
+                num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
+
+        num_pad_tokens = num_tokens_padded - num_tokens_unpadded
+        return num_pad_tokens
+
+    # This is where the second ubatch is adjusted to account for the padding.
+    # Should be called after attention metadata creation. This just pads
+    # the second ubatch slice out to the total number of tokens
+    # (num_tokens + padding).
+    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
+                             num_total_tokens: int):
+        padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
+                                           num_total_tokens)
+        ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice,
+                                       padded_second_ubatch_slice)
+
     def _pool(
         self,
         hidden_states: torch.Tensor,
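A small worked example of the arithmetic in get_local_padding and pad_out_ubatch_slice, with made-up numbers (the capture sizes and the "pad to the next captured size" rule are assumptions for illustration; the real padding comes from vllm_config.pad_for_cudagraph):

    # Illustration only; mirrors the two helpers above without vLLM imports.
    cudagraph_batch_sizes = [8, 16, 32, 64, 128]   # hypothetical capture sizes
    num_tokens_unpadded = 100

    # get_local_padding (piecewise-cudagraph branch): pad the batch up to a
    # captured size, here assumed to be the next size >= the real token count.
    num_tokens_padded = next(s for s in cudagraph_batch_sizes
                             if s >= num_tokens_unpadded)
    assert num_tokens_padded == 128

    # pad_out_ubatch_slice: once all DP ranks have agreed on 64 tokens per
    # ubatch, the second ubatch slice is stretched from the real token count
    # out to the padded total so both ubatches end up the same size.
    split = num_tokens_padded // 2                           # 64
    second_ubatch = slice(split, num_tokens_unpadded)        # tokens [64, 100)
    second_ubatch_padded = slice(split, num_tokens_padded)   # tokens [64, 128)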
@@ -1758,15 +1840,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        ubatch_slices: Optional[UBatchSlices] = None,
+        num_tokens_after_padding: Optional[torch.Tensor] = None,
     ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], torch.Tensor,
                Optional[IntermediateTensors], dict[str, Any]]:
 
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
-        # Padding for DP
-        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
-        num_input_tokens += num_pad
+        if ubatch_slices:
+            assert num_tokens_after_padding is not None
+            num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
+            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
+        elif ubatch_slices is None:
+            num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
+            num_pad, num_tokens_after_padding = self.get_dp_padding(
+                num_input_tokens)
+            num_input_tokens += num_pad
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1821,7 +1910,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return (
             num_scheduled_tokens,
             num_input_tokens,
-            num_tokens_across_dp,
+            num_tokens_after_padding,
             input_ids,
             inputs_embeds,
             positions,
@@ -2027,7 +2116,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # Prepare the decoder inputs.
                 (attn_metadata, logits_indices, spec_decode_metadata,
                  num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-                 max_query_len) = self._prepare_inputs(scheduler_output)
+                 max_query_len, ubatch_slices, num_tokens_after_padding
+                 ) = self._prepare_inputs(scheduler_output)
 
             finally:
                 if self.prepare_inputs_event is not None:
@@ -2042,7 +2132,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             positions,
             intermediate_tensors,
             model_kwargs,
-        ) = self._preprocess(scheduler_output, intermediate_tensors)
+        ) = self._preprocess(scheduler_output, intermediate_tensors,
+                             ubatch_slices, num_tokens_after_padding)
+
+        if ubatch_slices is not None:
+            num_input_tokens = num_input_tokens // 2
 
         uniform_decode = (max_query_len
                           == self.uniform_decode_query_len) and (
@@ -2062,6 +2156,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
+                ubatch_slices=ubatch_slices,
         ), record_function_or_nullcontext("Forward"),
         self.maybe_get_kv_connector_output(scheduler_output) as
         kv_connector_output):
@@ -2441,10 +2536,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # CudagraphWraper and CudagraphDispatcher of vllm.
 
         # wrap the model with full cudagraph wrapper if needed.
-        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \
+            and not self.parallel_config.enable_dbo:
             self.model = CUDAGraphWrapper(self.model,
                                           self.vllm_config,
                                           runtime_mode=CUDAGraphMode.FULL)
+        elif self.parallel_config.enable_dbo:
+            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+                self.model = UBatchWrapper(self.model, self.vllm_config,
+                                           CUDAGraphMode.FULL, self.device)
+            else:
+                self.model = UBatchWrapper(self.model, self.vllm_config,
+                                           CUDAGraphMode.NONE, self.device)
 
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, \
@@ -2642,6 +2745,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         force_attention: bool = False,
         uniform_decode: bool = False,
+        allow_microbatching: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
         create_mixed_batch: bool = False,
@@ -2667,12 +2771,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
+        ubatch_enabled = self.parallel_config.enable_dbo
+        num_tokens_across_dp = None
+        num_pad = 0
+        should_ubatch = False
+        if ubatch_enabled:
+            should_ubatch = num_tokens >= \
+                self.parallel_config.dbo_decode_token_threshold and \
+                allow_microbatching
+
+            (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
+                num_tokens, num_tokens, should_ubatch, self.vllm_config)
+
+            # Currently the dummy run should only be ubatching during
+            # cuda graph capture, meaning all DP ranks should already
+            # have the same batch size
+            if num_tokens_across_dp is not None:
+                assert int(num_tokens_across_dp[0]) == num_tokens // 2
+
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
-        # Padding for DP
-        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
-        num_tokens += num_pad
+        if not should_ubatch:
+            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+            num_tokens += num_pad
 
         # If cudagraph_mode.decode_mode() == FULL and
@@ -2690,6 +2812,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # for GQA/MQA.
         max_query_len = self.uniform_decode_query_len if uniform_decode else \
             num_tokens
+        if allow_microbatching:
+            assert self.uniform_decode_query_len == 1
+            assert uniform_decode is True
+            assert max_query_len == 1
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -2728,12 +2854,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
-        attn_metadata: Optional[dict[str, Any]] = None
+        ubatch_slices = None
+        # We currently only microbatch if the number of tokens is
+        # over a certain threshold.
+        if should_ubatch:
+            # We only support decode-only cudagraphs
+            assert num_reqs == num_tokens
+            assert num_tokens % 2 == 0
+            ubatch_slices = [
+                UBatchSlice(slice(0, num_reqs // 2), slice(0,
+                                                           num_tokens // 2)),
+                UBatchSlice(slice(num_reqs // 2, num_reqs),
+                            slice(num_tokens // 2, num_tokens))
+            ]
+
+        attn_metadata: Optional[PerLayerAttnMetadata] = None
 
         # If force_attention is True, we always capture attention. Otherwise,
         # it only happens for cudagraph_runtime_mode=FULL.
         if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
             attn_metadata = {}
+            if ubatch_slices is not None:
+                attn_metadata = [dict() for _ in range(len(ubatch_slices))]
 
         if create_mixed_batch:
             # In the mixed batch mode (used for FI warmup), we use
@@ -2766,12 +2908,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 slot_mapping=self.input_batch.
                 block_table[kv_cache_group_id].slot_mapping[:num_tokens],
                 causal=True)
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
-                attn_metadata_i = attn_group.metadata_builder\
-                    .build_for_cudagraph_capture(common_attn_metadata)
-                for layer_name in kv_cache_group_spec.layer_names:
-                    attn_metadata[layer_name] = attn_metadata_i
+                if ubatch_slices is not None:
+                    common_attn_metadata_list = split_attn_metadata(
+                        ubatch_slices, common_attn_metadata)
+                    for ubid, common_attn_metadata in enumerate(
+                            common_attn_metadata_list):
+                        assert common_attn_metadata.max_query_len == 1
+                        attn_metadata_i = (attn_group\
+                            .get_metadata_builder(ubatch_id=ubid)\
+                            .build_for_cudagraph_capture(common_attn_metadata))
+                        for layer_name in kv_cache_group_spec.layer_names:
+                            assert type(attn_metadata) is list
+                            attn_metadata[ubid][
+                                layer_name] = attn_metadata_i
+                else:
+                    assert type(attn_metadata) is dict
+                    attn_metadata_i = attn_group.get_metadata_builder()\
+                        .build_for_cudagraph_capture(common_attn_metadata)
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens, remove_lora):
@@ -2818,13 +2974,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     f"Cudagraph runtime mode mismatch at dummy_run. "
                     f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
 
+        if ubatch_slices is not None:
+            num_tokens = num_tokens // 2
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
-                batch_descriptor=batch_descriptor):
+                batch_descriptor=batch_descriptor,
+                ubatch_slices=ubatch_slices):
             outputs = self.model(
                 input_ids=input_ids,
                 positions=positions,
@@ -3096,6 +3255,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
             cudagraph_mode = self.compilation_config.cudagraph_mode
+            assert cudagraph_mode is not None
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                 cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
 
@@ -3153,6 +3313,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 desc="Capturing CUDA graphs ({}, {})".format(
                     "decode" if uniform_decode else "mixed prefill-decode",
                     cudagraph_runtime_mode.name))
+            enable_dbo = self.parallel_config.enable_dbo
+            # DBO only supports running full cudagraphs with uniform
+            # decode lengths
+            if enable_dbo and uniform_decode:
+                for num_tokens in compilation_cases:
+                    # If the number of tokens is below the microbatching
+                    # threshold, don't generate a microbatched cudagraph
+                    if (num_tokens
+                            < self.parallel_config.dbo_decode_token_threshold):
+                        continue
+
+                    # Warmup
+                    for _ in range(
+                            self.compilation_config.cudagraph_num_of_warmups):
+                        force_attention = (
+                            cudagraph_runtime_mode == CUDAGraphMode.FULL)
+                        self._dummy_run(num_tokens,
+                                        cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                        force_attention=force_attention,
+                                        uniform_decode=True,
+                                        allow_microbatching=True,
+                                        skip_eplb=True)
+
+                    # Graph Capture
+                    self._dummy_run(num_tokens,
+                                    cudagraph_runtime_mode=CUDAGraphMode.FULL,
+                                    uniform_decode=True,
+                                    allow_microbatching=True,
+                                    skip_eplb=True)
             # We skip EPLB here since we don't want to record dummy metrics
             for num_tokens in compilation_cases:
                 for _ in range(self.compilation_config.cudagraph_num_of_warmups):
@@ -3219,14 +3408,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     ) -> list[AttentionGroup]:
         attn_groups: list[AttentionGroup] = []
         for attn_backend, layer_names in attn_backends_map.items():
-            attn_metadata_builder_i = attn_backend.get_builder_cls()(
+            attn_metadata_builders = []
+            attn_metadata_builders.append(attn_backend.get_builder_cls()(
                 kv_cache_spec,
                 layer_names,
                 self.vllm_config,
                 self.device,
-            )
+            ))
+            if self.parallel_config.enable_dbo:
+                attn_metadata_builders.append(
+                    attn_backend.get_builder_cls()(
+                        kv_cache_spec,
+                        layer_names,
+                        self.vllm_config,
+                        self.device,
+                    ))
             attn_group = AttentionGroup(attn_backend,
-                                        attn_metadata_builder_i,
+                                        attn_metadata_builders,
                                         layer_names)
             attn_groups.append(attn_group)
         return attn_groups
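AttentionGroup.get_metadata_builder is called throughout the hunks above, but its body is not part of this excerpt (the AttentionGroup change itself is only partially visible at the end of the diff). A hypothetical sketch consistent with those call sites — a list holding one builder per ubatch, defaulting to the first — might look like:

    # Hypothetical stand-in for vllm/v1/worker/utils.py::AttentionGroup; only
    # the builder lookup is sketched, and the name is made up to avoid
    # suggesting this is the real implementation.
    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class AttentionGroupSketch:
        metadata_builders: list[Any] = field(default_factory=list)

        def get_metadata_builder(self, ubatch_id: int = 0) -> Any:
            # One builder per ubatch (two when DBO is enabled); callers that
            # don't microbatch simply take builder 0.
            assert 0 <= ubatch_id < len(self.metadata_builders)
            return self.metadata_builders[ubatch_id]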
@@ -3246,11 +3444,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         min_cg_builder_name = None
 
         for attn_group in self._attn_group_iterator():
-            builder = attn_group.metadata_builder
+            builder = attn_group.get_metadata_builder()
             if builder.cudagraph_support.value < min_cg_support.value:
                 min_cg_support = builder.cudagraph_support
                 min_cg_builder_name = builder.__class__.__name__
 
         # Flexible resolve the cudagraph mode
         cudagraph_mode = self.compilation_config.cudagraph_mode
         # check cudagraph for mixed batch is supported
@@ -3316,7 +3513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            is compatible (e.g., decode threshold is the same)
         """
         for group in self._attn_group_iterator():
-            attn_metadata_builder_i = group.metadata_builder
+            attn_metadata_builder_i = group.get_metadata_builder()
 
             # check that if any backends reorder batches; that the reordering
             # is compatible (e.g., decode threshold is the same)
vllm/v1/worker/gpu_ubatch_wrapper.py (new file, 303 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import threading
from typing import Any, Callable, Optional

import torch

from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import (create_forward_context, get_forward_context,
                                  override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

logger = init_logger(__name__)


@dataclasses.dataclass
class UbatchMetadata:
    context: UBatchContext
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    num_tokens: int


@dataclasses.dataclass
class CUDAGraphMetaData:
    cudagraph: torch.cuda.CUDAGraph
    ubatch_metadata: UbatchMetadata
    outputs: Optional[Any] = None


class UBatchWrapper:

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, device: torch.cuda.device):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.cuda.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode)
            self.graph_pool = current_platform.get_global_graph_pool()

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"cudagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure
        that each of the ubatch threads initializes the cuda context before
        we start the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its cuda context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured cudagraph along with its
        metadata and returns.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            ubatch_context = ubatch_metadata.context
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        num_tokens = ubatch_metadata[0].num_tokens + \
            ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_capture_ubatch_thread,
                                          args=(
                                              results,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = \
                CUDAGraphMetaData(
                    cudagraph=torch.cuda.CUDAGraph(),
                    ubatch_metadata=ubatch_metadata,
                )
            with torch.cuda.graph(cudagraph_metadata.cudagraph,
                                  stream=compute_stream,
                                  pool=self.graph_pool):
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                sorted_results = [value for position, value in sorted(results)]
                result = torch.cat(sorted_results, dim=0)
                cudagraph_metadata.outputs = result
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_ubatch_thread,
                                          args=(
                                              results,
                                              model,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        sorted_results = [value for position, value in sorted(results)]
        result = torch.cat(sorted_results, dim=0)
        return result

    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                              positions, inputs_embeds, intermediate_tensors,
                              compute_stream, dp_metadata, batch_descriptor,
                              cudagraph_runtime_mode) -> list[UbatchMetadata]:

        # Create one forward context per ubatch
        forward_contexts = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            forward_contexts.append(
                create_forward_context(
                    attn_metadata[i] if attn_metadata is not None else None,
                    self.vllm_config,
                    dp_metadata=dp_metadata,
                    batch_descriptor=batch_descriptor,
                    cudagraph_runtime_mode=cudagraph_runtime_mode))

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier)

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
                sliced_intermediate_tensors = \
                self._slice_model_inputs(
                    ubatch_slice.token_slice, input_ids, positions,
                    inputs_embeds, intermediate_tensors)
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=ubatch_slice.token_slice.stop -
                    ubatch_slice.token_slice.start))

        return ubatch_metadata

    def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
                            inputs_embeds, intermediate_tensors):
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        sliced_inputs_embeds = inputs_embeds[
            tokens_slice] if inputs_embeds else None
        sliced_intermediate_tensors = intermediate_tensors[
            tokens_slice] if intermediate_tensors else None

        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
                sliced_intermediate_tensors)

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                          CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.cudagraph_wrapper is not None
                return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        num_tokens = (ubatch_slices[0].token_slice.stop -
                      ubatch_slices[0].token_slice.start) * 2
        input_ids = kwargs['input_ids']
        positions = kwargs['positions']
        intermediate_tensors = kwargs['intermediate_tensors']
        inputs_embeds = kwargs['inputs_embeds']
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        if num_tokens not in self.cudagraphs \
            and cudagraph_runtime_mode is CUDAGraphMode.FULL:
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE)

            return self._capture_ubatches(ubatch_metadata, self.model)
        elif num_tokens in self.cudagraphs:
            cudagraph_metadata = self.cudagraphs[num_tokens]
            cudagraph_metadata.cudagraph.replay()
            return cudagraph_metadata.outputs
        else:
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=CUDAGraphMode.NONE)
            return self._run_ubatches(ubatch_metadata, self.model)
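A toy, CPU-only sketch of how the wrapper stitches the two ubatch outputs back together: each thread appends (ubatch_id, output) to a shared list and the caller sorts by id before concatenating, so the result order never depends on which thread finishes first. Plain Python lists stand in for torch tensors here; this is an illustration, not the wrapper's code.

    import threading

    results: list[tuple[int, list[int]]] = []

    def run_ubatch(ubatch_id: int, tokens: list[int]) -> None:
        # Fake "model output": just scale the input tokens.
        results.append((ubatch_id, [t * 10 for t in tokens]))

    threads = [
        threading.Thread(target=run_ubatch, args=(1, [4, 5, 6])),
        threading.Thread(target=run_ubatch, args=(0, [1, 2, 3])),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Sort by ubatch id, then concatenate, mirroring sorted(results) + torch.cat.
    sorted_outputs = [out for _, out in sorted(results)]
    merged = [x for out in sorted_outputs for x in out]
    assert merged == [10, 20, 30, 40, 50, 60]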
vllm/v1/worker/ubatch_splitting.py (new file, 155 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
                                         is_second_ubatch_empty)

logger = init_logger(__name__)


def should_ubatch_with_num_tokens(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
    dp_size = vllm_config.parallel_config.data_parallel_size
    dp_rank = vllm_config.parallel_config.data_parallel_rank
    return DPMetadata.should_ubatch_across_dp(should_ubatch,
                                              orig_num_tokens_per_ubatch,
                                              padded_num_tokens_per_ubatch,
                                              dp_size, dp_rank)


def get_dp_padding_ubatch(
        num_tokens_unpadded: int, num_tokens_padded: int,
        should_attempt_ubatching: bool,
        vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do. If this function decides
    not to run with microbatching, it will "abort", meaning that no padding
    information will be returned to the caller. It will return (False, None).

    2. Determines the total number of tokens that each rank will run.
    All ranks will be padded out so that they run with the same number
    of tokens.

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if should_ubatch is False
    ]
    """
    assert num_tokens_padded >= num_tokens_unpadded
    dp_size = vllm_config.parallel_config.data_parallel_size
    if dp_size == 1:
        # Early exit.
        return False, None

    # If this DP rank doesn't want to attempt microbatching
    if not should_attempt_ubatching:
        (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
            False, 0, 0, vllm_config)
        assert should_ubatch is False
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    # Round up to the next multiple of two for even divisibility
    num_tokens_padded = round_up(num_tokens_padded, 2)
    num_tokens_per_ubatch = num_tokens_padded // 2
    should_ubatch = True

    # Sanity check that the existing padding isn't giving us an empty second
    # ubatch. Abort if so
    if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
        logger.debug(
            "Empty second µbatch detected: unpadded tokens: %s, padded "
            "tokens: %s", num_tokens_unpadded, num_tokens_padded)
        should_ubatch = False

    # Note that we compute the number of padded tokens per ubatch
    (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
        should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
        vllm_config)
    if not should_ubatch:
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    assert num_tokens_across_dp is not None

    max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
    num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                            dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return should_ubatch, num_tokens_after_padding


def ubatch_split(
    max_num_scheduled_tokens: int,
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
    """
    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    Returns: tuple[
        ubatch_slices: if this is set then all DP ranks have agreed to
        microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if ubatch_slices is None
    ]
    """
    parallel_config = vllm_config.parallel_config
    # Don't bother with the should_ubatch handshaking unless microbatching
    # is enabled
    if not parallel_config.enable_dbo:
        return (None, None)

    # Check preconditions for microbatching
    should_attempt_ubatching = \
        parallel_config.enable_dbo and \
        num_tokens_unpadded >= \
        parallel_config.dbo_decode_token_threshold \
        and max_num_scheduled_tokens == 1

    # Don't microbatch unless every other DP worker is also microbatching
    num_tokens_after_padding = None
    (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
        num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
        vllm_config)
    if not should_ubatch:
        return (None, None)

    # This doesn't actually pad the ubatch slices. It just initializes the
    # split point to the padded value so that padding can be applied
    # to the second ubatch in pad_out_ubatch_slice after attention
    # metadata creation
    assert num_tokens_after_padding is not None
    total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
    padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                       num_tokens_unpadded)

    # Note there's an assumption here that there's 1 token per request
    ubatch_slices = [
        UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
        UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
    ]

    return (ubatch_slices, num_tokens_after_padding)
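A worked example of the split decision above, with made-up numbers (two DP ranks, DBO enabled, pure decode so max_num_scheduled_tokens == 1):

    # Illustration only. Rank A has 100 decode tokens, rank B has 128; the
    # handshake makes both adopt the larger per-ubatch size.
    num_tokens_unpadded = 100                    # rank A
    per_ubatch_local = 100 // 2                  # 50 tokens per ubatch locally
    per_ubatch_agreed = max(per_ubatch_local, 128 // 2)   # -> 64 after the DP max

    # Slices produced by ubatch_split on rank A (token and request slices
    # coincide because each decode request contributes exactly one token):
    first_ubatch = slice(0, per_ubatch_agreed)                      # [0, 64)
    second_ubatch = slice(per_ubatch_agreed, num_tokens_unpadded)   # [64, 100)

    # pad_out_ubatch_slice later stretches the second ubatch to the padded total:
    num_total_tokens = per_ubatch_agreed * 2                        # 128
    second_ubatch_padded = slice(per_ubatch_agreed, num_total_tokens)   # [64, 128)

    # The is_second_ubatch_empty guard does not trip here (64 < 2 * 50),
    # so microbatching goes ahead on both ranks.
    assert not (per_ubatch_agreed >= 2 * per_ubatch_local)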
vllm/v1/worker/ubatch_utils.py (new file, 19 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

from typing_extensions import TypeAlias


@dataclass
class UBatchSlice:
    request_slice: slice
    token_slice: slice


UBatchSlices: TypeAlias = list[UBatchSlice]


def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
                           padded_num_tokens_per_ubatch: int) -> bool:
    return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch
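A quick sanity check of the helper above with hand-picked, illustrative numbers:

    # Mirrors is_second_ubatch_empty: the second ubatch is "empty" when the
    # first (padded) half already covers every real token.
    orig_num_tokens, padded_num_tokens = 100, 128
    assert not (padded_num_tokens >= 2 * orig_num_tokens)   # 36 real tokens remain for ubatch 1

    orig_num_tokens, padded_num_tokens = 10, 32
    assert padded_num_tokens >= 2 * orig_num_tokens         # the 16-token first half swallows all 10 real tokens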
vllm/v1/worker/ubatching.py (new file, 211 lines)
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import forward_context
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.utils import current_stream
|
||||||
|
|
||||||
|
_THREAD_ID_TO_CONTEXT: dict = {}
|
||||||
|
_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None]
|
||||||
|
|
||||||
|
|
||||||
|
class UBatchContext:
|
||||||
|
"""
|
||||||
|
Context manager for micro-batching synchronization using threading events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
id: int,
|
||||||
|
comm_stream: torch.cuda.Stream,
|
||||||
|
compute_stream: torch.cuda.Stream,
|
||||||
|
forward_context: ForwardContext,
|
||||||
|
ready_barrier: threading.Barrier,
|
||||||
|
cpu_wait_event: threading.Event,
|
||||||
|
cpu_signal_event: threading.Event,
|
||||||
|
gpu_comm_done_event: torch.cuda.Event,
|
||||||
|
gpu_compute_done_event: torch.cuda.Event,
|
||||||
|
schedule: str = "default"):
|
||||||
|
self.id = id
|
||||||
|
self.comm_stream = comm_stream
|
||||||
|
self.compute_stream = compute_stream
|
||||||
|
self.forward_context = forward_context
|
||||||
|
self.ready_barrier = ready_barrier
|
||||||
|
self.cpu_wait_event = cpu_wait_event
|
||||||
|
self.cpu_signal_event = cpu_signal_event
|
||||||
|
self.current_stream = compute_stream
|
||||||
|
self.gpu_comm_done_event = gpu_comm_done_event
|
||||||
|
self.gpu_compute_done_event = gpu_compute_done_event
|
||||||
|
self.schedule = schedule
|
||||||
|
self.recv_hook = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||||
|
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
|
||||||
|
_CURRENT_CONTEXTS[self.id] = self
|
||||||
|
self.ready_barrier.wait()
|
||||||
|
|
||||||
|
self.cpu_wait_event.wait()
|
||||||
|
self.cpu_wait_event.clear()
|
||||||
|
self._restore_context()
|
||||||
|
# Assume we start on the compute stream
|
||||||
|
assert current_stream() == self.compute_stream
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||||
|
_CURRENT_CONTEXTS[self.id] = None
|
||||||
|
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||||
|
self.maybe_run_recv_hook()
|
||||||
|
self.cpu_signal_event.set()
|
||||||
|
self.cpu_wait_event.clear()
|
||||||
|
self.current_stream = self.compute_stream
|
||||||
|
torch.cuda.set_stream(self.current_stream)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _restore_context(self):
|
||||||
|
forward_context._forward_context = self.forward_context
|
||||||
|
torch.cuda.set_stream(self.current_stream)
|
||||||
|
|
||||||
|
def update_stream(self, stream):
|
||||||
|
self.current_stream = stream
|
||||||
|
torch.cuda.set_stream(self.current_stream)
|
||||||
|
|
||||||
|
def _signal_comm_done(self):
|
||||||
|
self.gpu_comm_done_event.record(self.comm_stream)
|
||||||
|
|
||||||
|
def _signal_compute_done(self):
|
||||||
|
self.gpu_compute_done_event.record(self.compute_stream)
|
||||||
|
|
||||||
|
def _wait_compute_done(self):
|
||||||
|
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||||
|
|
||||||
|
def _wait_comm_done(self):
|
||||||
|
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||||
|
|
||||||
|
def _cpu_yield(self):
|
||||||
|
# It is critical for correctness that only one thread is running
|
||||||
|
# at a time. These asserts just make sure that this is the only
|
||||||
|
# thread running before waking the other one up and going to sleep
|
||||||
|
assert forward_context._forward_context == self.forward_context
|
||||||
|
assert current_stream() == self.current_stream
|
||||||
|
assert not self.cpu_wait_event.is_set()
|
||||||
|
|
||||||
|
self.cpu_signal_event.set()
|
||||||
|
self.cpu_wait_event.wait()
|
||||||
|
self.cpu_wait_event.clear()
|
||||||
|
self._restore_context()
|
||||||
|
|
||||||
|
    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        if self.current_stream != current_stream():
            self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
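The yield_and_switch_* pair above keeps the GPU ordered across the two streams with a record/wait event handshake. A simplified, standalone illustration follows (assumes a CUDA device; the matmul and the scale are stand-ins for real compute and communication kernels and are not part of this diff):

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    compute_done = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")
    with torch.cuda.stream(compute_stream):
        y = x @ x                         # "compute" phase
    compute_done.record(compute_stream)   # mirrors _signal_compute_done()

    comm_stream.wait_event(compute_done)  # mirrors _wait_compute_done()
    with torch.cuda.stream(comm_stream):
        z = y * 2                         # stand-in for a communication kernel
    torch.cuda.synchronize()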
def dbo_enabled() -> bool:
    return len(_THREAD_ID_TO_CONTEXT) > 0


def dbo_current_ubatch_id() -> int:
    if len(_THREAD_ID_TO_CONTEXT) == 0:
        return 0
    return _THREAD_ID_TO_CONTEXT[threading.get_ident()]


def _register_ubatch_function(func):

    def wrapper(*args, **kwargs):
        if len(_THREAD_ID_TO_CONTEXT) > 0:
            ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
            ctx = _CURRENT_CONTEXTS[ctx_idx]
            func(ctx, *args, **kwargs)

    return wrapper


dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_maybe_run_recv_hook = _register_ubatch_function(
    UBatchContext.maybe_run_recv_hook)
dbo_switch_to_comm_sync = _register_ubatch_function(
    UBatchContext.switch_to_comm_sync)
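The dbo_* wrappers above resolve the calling thread's context at call time and silently do nothing when no micro-batching context is registered, so call sites in model code do not need an `if dbo_enabled()` guard. A standalone sketch of that pattern (illustrative names only, not the vllm API):

import threading

_THREAD_TO_CTX: dict[int, object] = {}

def register_hook(method):
    def wrapper(*args, **kwargs):
        ctx = _THREAD_TO_CTX.get(threading.get_ident())
        if ctx is not None:           # only act inside a micro-batch thread
            method(ctx, *args, **kwargs)
    return wrapper

class Ctx:
    def ping(self) -> None:
        print("running inside a ubatch context")

do_ping = register_hook(Ctx.ping)

do_ping()                             # outside any context: no-op
_THREAD_TO_CTX[threading.get_ident()] = Ctx()
do_ping()                             # inside: dispatches to the context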
def dbo_register_recv_hook(recv_hook):
    if len(_THREAD_ID_TO_CONTEXT) > 0:
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2]
        next_ctx.recv_hook = recv_hook


def make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.cuda.Stream,
    comm_stream: torch.cuda.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    """
    Create a context manager for micro-batching synchronization.
    """
    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.cuda.Event() for _ in range(num_micro_batches)
    ]

    assert len(forward_contexts) == 2

    ctxs = []
    for i in range(num_micro_batches):
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
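A hedged driver sketch for make_ubatch_contexts (assumes a CUDA device and the definitions above; `None` stands in for real ForwardContext objects, and the trivial worker body is hypothetical): two threads each enter their UBatchContext, and the coordinating thread releases micro-batch 0 to start the ping-pong.

import threading
import torch

def worker(ctx: UBatchContext) -> None:
    with ctx:            # blocks until this micro-batch is handed control
        dbo_yield()      # a real body would interleave compute and comm here

ready = threading.Barrier(parties=3)               # two workers + this thread
ctxs = make_ubatch_contexts(
    num_micro_batches=2,
    compute_stream=torch.cuda.Stream(),
    comm_stream=torch.cuda.Stream(),
    forward_contexts=[None, None],                 # stand-ins for ForwardContext
    ready_barrier=ready,
)
threads = [threading.Thread(target=worker, args=(ctx, )) for ctx in ctxs]
for t in threads:
    t.start()
ready.wait()                                       # both contexts registered
ctxs[0].cpu_wait_event.set()                       # hand control to ubatch 0
for t in threads:
    t.join()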
@ -130,9 +130,17 @@ class MultiModalBudget:
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    metadata_builder: AttentionMetadataBuilder
+    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
 
+    def get_metadata_builder(self,
+                             ubatch_id: Optional[int] = None
+                             ) -> AttentionMetadataBuilder:
+        if ubatch_id is None:
+            return self.metadata_builders[0]
+        assert len(self.metadata_builders) > ubatch_id
+        return self.metadata_builders[ubatch_id]
+
 
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
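At call sites, the intent is that attention metadata is built per micro-batch when dual-batch overlap is active. An illustrative selection of the builder (not taken from this diff; `attn_group` is an AttentionGroup instance and the dbo_* helpers come from the ubatching code above):

ubatch_id = dbo_current_ubatch_id() if dbo_enabled() else None
builder = attn_group.get_metadata_builder(ubatch_id)   # builder for this ubatch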