[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)

This commit is contained in:
Lucas Wilkinson 2025-11-26 14:07:13 -05:00 committed by GitHub
parent 430dd4d9eb
commit 56539cddac
10 changed files with 401 additions and 283 deletions

View File

@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
```python
class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
num_reqs: int
uniform: bool = False
has_lora: bool = False
```
where `num_tokens` can be the padded token length, and `uniform_decode` is determined by whether the batch's `max_query_len` equals the desired `max_query_len` of a uniform decode and the number of scheduled tokens is divisible by that desired `max_query_len`.
where `num_tokens` can be the padded token length, and `uniform` indicates whether all requests in the batch have the same query length. Many attention backends only support full cudagraphs when batches are uniform; pure decode batches are uniform but may not have query length 1 (i.e. `num_tokens == num_reqs`). This occurs in the validation pass of speculative decoding, where "decode" batches have a query length of `1+num_spec_tokens`.
The goal of this structure is to uniquely identify a (padded) batch with the minimal set of fields corresponding to a CUDA Graphs item. It is safe to exclude fields like `uniform_query_len` because it is currently a runtime constant for a given setup: either `1` for a plain pure decode or `1+num_spec_tokens` for the validation phase of speculative decoding.
The goal of this structure is to uniquely identify a (padded) batch with the minimal set of fields corresponding to a CUDA Graphs item.
!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., by including more fields such as `uniform_query_len` to support multiple uniform decode length settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not purely token-length aware (for example, some multi-modal inputs).
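
For example, here is a minimal sketch of the distinction (numbers are illustrative; with 3 speculative tokens the uniform decode query length is `1 + 3 = 4`):

```python
from vllm.forward_context import BatchDescriptor

# Spec-decode validation batch: 8 requests, query length 4 each (uniform,
# but not num_tokens == num_reqs).
uniform_desc = BatchDescriptor(num_tokens=32, num_reqs=8, uniform=True)

# Relaxed form used for PIECEWISE (mixed prefill-decode) cudagraphs: request
# count and uniformity no longer matter, only the (padded) token count.
relaxed = uniform_desc.relax_for_mixed_batch_cudagraphs()
assert relaxed == BatchDescriptor(num_tokens=32, num_reqs=None, uniform=False)
```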

View File

@ -42,12 +42,24 @@ def _create_vllm_config(
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
# mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes:
compilation_config.max_cudagraph_capture_size = (
compilation_config.cudagraph_capture_sizes[-1]
)
compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)
return mock_config
@ -109,9 +121,11 @@ class TestCudagraphDispatcher:
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform_decode=False,
uniform=False,
)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
@ -122,32 +136,37 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_no_match)
rt_mode, key = dispatcher.dispatch(
num_tokens=15, uniform_decode=False, has_lora=False
)
assert rt_mode == CUDAGraphMode.NONE
assert key is None
assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fallback mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.non_uniform
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE

View File

@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
"""
num_tokens: int
uniform_decode: bool = False
num_reqs: int | None = None
"""
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
Number of requests in the batch. Can be None for PIECEWISE cudagraphs, since
piecewise cudagraphs can handle any number of requests.
"""
uniform: bool = False
"""
True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a non-uniform version of the current batch descriptor.
Return a relaxed version of the current batch descriptor that is still
compatible with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)

View File

@ -930,31 +930,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fill the padding region with reasonable values
# on the CPU.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)
else:
num_input_tokens = num_decode_tokens
num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph

View File

@ -107,6 +107,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case of a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
return (
block_idx_last_computed_token,

View File

@ -72,6 +72,7 @@ class CommonAttentionMetadata:
num_reqs: int
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
@ -857,7 +858,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# A query len of 0 indicates a padded request; leave these at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
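
A small self-contained sketch (with made-up query lengths) of how the mask above keeps padded zero-length requests grouped with the prefills at the back of the batch:

```python
import torch

# Hypothetical query lengths: two decodes, one prefill, and two padded
# (0-length) requests sitting at the back of the batch.
query_lens = torch.tensor([1, 1, 37, 0, 0])
decode_threshold = 1

# Without the `== 0` clause, padded requests would be misclassified as decodes.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
# -> tensor([False, False, True, True, True])
```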

View File

@ -4,6 +4,9 @@ from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
@ -28,7 +31,11 @@ class CudagraphDispatcher:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.uniform_decode_query_len = (
1
if not self.vllm_config.speculative_config
else 1 + self.vllm_config.speculative_config.num_speculative_tokens
)
# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
@ -36,25 +43,42 @@ class CudagraphDispatcher:
CUDAGraphMode.FULL: set(),
}
not_use_piecewise_compilation = (
not self.cudagraph_mode.requires_piecewise_compilation()
)
assert (
not_use_piecewise_compilation
not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
or self.compilation_config.is_attention_compiled_piecewise()
), (
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
"cudagraph_mode piecewise cudagraphs is used, "
"and attention should be in splitting_ops or "
"inductor splitting should be used. "
f"cudagraph_mode={self.cudagraph_mode}, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
f"compilation_mode={self.compilation_config.mode}, "
f"splitting_ops={self.compilation_config.splitting_ops}"
)
self.keys_initialized = False
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False
num_reqs = min(num_tokens_padded, max_num_seqs)
return BatchDescriptor(
num_tokens=num_tokens_padded,
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
)
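
A worked example of the arithmetic in `_create_padded_batch_descriptor` (all numbers illustrative, assuming `pad_for_cudagraph(28) == 32` for this setup):

```python
# Uniform decode with speculative decoding: query length 1 + 3 = 4 per request.
uniform_decode_query_len = 4
num_tokens, num_reqs = 28, 7        # 7 requests * 4 tokens before padding
num_tokens_padded = 32              # assumed result of pad_for_cudagraph(28)

# For FULL cudagraphs the request count follows the token padding, so the
# captured graph sees 8 uniform requests instead of 7.
assert num_tokens_padded % uniform_decode_query_len == 0
num_reqs_padded = num_tokens_padded // uniform_decode_query_len
assert num_reqs_padded == 8
```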
def add_cudagraph_key(
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
@ -66,7 +90,9 @@ class CudagraphDispatcher:
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
):
# This should be called only after attention backend is initialized.
# This should be called only after the attention backend is initialized, so that
# we get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
@ -86,9 +112,9 @@ class CudagraphDispatcher:
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
self._create_padded_batch_descriptor(
bs, False, has_lora
).relax_for_mixed_batch_cudagraphs(),
)
# if decode cudagraph mode is FULL, and we don't already have mixed
@ -109,40 +135,49 @@ class CudagraphDispatcher:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)
self.keys_initialized = True
def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given the batch properties (num_tokens, uniform_decode, has_lora) and whether
cascade attention is used, dispatch to a cudagraph runtime mode and the
corresponding batch descriptor. The returned descriptor may be relaxed, since
a uniform batch can be dispatched to a graph that supports a more general
(non-uniform) batch.
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
return CUDAGraphMode.NONE, None
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
non_uniform_key = batch_descriptor.non_uniform
# if a batch uses cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc
# also check if non-uniform key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
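
A self-contained sketch of the fallback chain above, mimicking a FULL_AND_PIECEWISE setup with a single capture size of 8 (no spec decode, no LoRA; the key sets here are illustrative, not built by the real dispatcher):

```python
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor

# Toy key sets: one uniform FULL graph and one relaxed PIECEWISE graph.
keys = {
    CUDAGraphMode.FULL: {BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)},
    CUDAGraphMode.PIECEWISE: {
        BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
    },
}

def toy_dispatch(desc: BatchDescriptor) -> tuple[CUDAGraphMode, BatchDescriptor]:
    relaxed = desc.relax_for_mixed_batch_cudagraphs()
    if desc in keys[CUDAGraphMode.FULL]:
        return CUDAGraphMode.FULL, desc          # exact (uniform) full-graph hit
    if relaxed in keys[CUDAGraphMode.FULL]:
        return CUDAGraphMode.FULL, relaxed       # mixed-batch full graph
    if relaxed in keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, relaxed  # piecewise fallback
    return CUDAGraphMode.NONE, BatchDescriptor(desc.num_tokens)

# A uniform decode batch hits the FULL graph; a mixed batch of the same padded
# size falls back to the PIECEWISE graph.
assert toy_dispatch(BatchDescriptor(8, num_reqs=8, uniform=True))[0] == CUDAGraphMode.FULL
assert toy_dispatch(BatchDescriptor(8, num_reqs=4, uniform=False))[0] == CUDAGraphMode.PIECEWISE
```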

View File

@ -9,6 +9,7 @@ from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
@ -88,6 +89,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
padded_second_ubatch_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
ubatch_slices[1] = UBatchSlice(
padded_second_ubatch_slice, padded_second_ubatch_slice
)
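
A minimal sketch of what `_pad_out_ubatch_slice` does, using a simplified stand-in for `UBatchSlice` (the real class lives in `vllm.v1.worker.ubatch_utils`; the field names here are assumptions for illustration):

```python
from dataclasses import dataclass

@dataclass
class UBatchSlice:  # simplified stand-in, not the real class
    token_slice: slice
    request_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start

# 96 scheduled tokens split into two ubatches, then DP-padded up to 128 total.
ubatch_slices = [
    UBatchSlice(slice(0, 48), slice(0, 48)),
    UBatchSlice(slice(48, 96), slice(48, 96)),
]
num_total_tokens = 128  # num_tokens + padding agreed across DP ranks

# Only the second ubatch absorbs the padding; the first keeps its original extent.
padded = slice(ubatch_slices[1].token_slice.start, num_total_tokens)
ubatch_slices[1] = UBatchSlice(padded, padded)

assert sum(s.num_tokens for s in ubatch_slices) == num_total_tokens
```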
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
@ -220,11 +232,14 @@ def coordinate_batch_across_dp(
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2
num_tokens_padded = int(num_tokens_after_padding[0].item())
token_split_point = int(num_tokens_padded) // 2
assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
return (ubatch_slices, num_tokens_after_padding)

View File

@ -151,7 +151,6 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
)
@ -1239,17 +1238,13 @@ class GPUModelRunner(
self,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: np.ndarray,
max_num_scheduled_tokens: int,
) -> tuple[
torch.Tensor,
SpecDecodeMetadata | None,
UBatchSlices | None,
torch.Tensor | None,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
ubatch_slices, num_tokens_across_dp,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -1364,28 +1359,6 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set enforce_eager on the prefiller in
# a P/D setup and still use CUDA graphs (enabled by this padding) on the
# decoder.
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.parallel_config,
allow_microbatching=True,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens,
)
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
)
@ -1486,15 +1459,15 @@ class GPUModelRunner(
return (
logits_indices,
spec_decode_metadata,
ubatch_slices,
num_tokens_across_dp,
)
def _build_attention_metadata(
self,
total_num_scheduled_tokens: int,
max_num_scheduled_tokens: int,
num_tokens: int,
num_reqs: int,
max_query_len: int,
num_tokens_padded: int | None = None,
num_reqs_padded: int | None = None,
ubatch_slices: UBatchSlices | None = None,
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
@ -1505,6 +1478,9 @@ class GPUModelRunner(
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs
logits_indices_padded = None
num_logits_indices = None
if logits_indices is not None:
@ -1522,28 +1498,13 @@ class GPUModelRunner(
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
# Used in the below loop
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
seq_lens = self.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs
]
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
if self.dcp_world_size > 1:
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
spec_decode_common_attn_metadata = None
if for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding window models we need
# to make sure the backend sees a max_seq_len that is larger than the sliding
@ -1559,6 +1520,22 @@ class GPUModelRunner(
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Used in the below loop, uses padded shapes
query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
seq_lens = self.seq_lens.gpu[:num_reqs_padded]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
if self.dcp_world_size > 1:
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
spec_decode_common_attn_metadata = None
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_gid, kv_cache_group in enumerate(
@ -1567,30 +1544,31 @@ class GPUModelRunner(
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs,
num_reqs_padded,
)
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
(num_reqs, 1),
(num_tokens_padded, 1),
dtype=torch.int32,
device=self.device,
)
slot_mapping = torch.zeros(
(total_num_scheduled_tokens,),
(num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
# graph mode. `blk_table_tensor` is filled with -1 to match mamba's PAD_SLOT_ID
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
@ -1598,9 +1576,9 @@ class GPUModelRunner(
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
num_actual_tokens=num_tokens_padded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
@ -1631,9 +1609,11 @@ class GPUModelRunner(
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
num_accepted_tokens=self.num_accepted_tokens.gpu[
:num_reqs_padded
],
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
:num_reqs
:num_reqs_padded
],
)
@ -1677,6 +1657,7 @@ class GPUModelRunner(
def _compute_cascade_attn_prefix_lens(
self,
num_scheduled_tokens: np.ndarray,
num_computed_tokens: np.ndarray,
num_common_prefix_blocks: list[int],
) -> list[list[int]] | None:
"""
@ -1699,6 +1680,7 @@ class GPUModelRunner(
# 0 if cascade attention should not be used
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_computed_tokens,
num_common_prefix_blocks[kv_cache_gid],
attn_group.kv_cache_spec,
attn_group.get_metadata_builder(),
@ -1711,6 +1693,7 @@ class GPUModelRunner(
def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
num_computed_tokens: np.ndarray,
num_common_prefix_blocks: int,
kv_cache_spec: KVCacheSpec,
attn_metadata_builder: AttentionMetadataBuilder,
@ -1777,10 +1760,7 @@ class GPUModelRunner(
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min(
common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
)
common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (
common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
@ -2334,19 +2314,6 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
# This is where the second ubatch is adjusted to account for the padding.
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
@staticmethod
def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
padded_second_ubatch_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
ubatch_slices[1] = UBatchSlice(
padded_second_ubatch_slice, padded_second_ubatch_slice
)
def _pool(
self,
hidden_states: torch.Tensor,
@ -2391,18 +2358,7 @@ class GPUModelRunner(
pooler_output=pooler_output,
)
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and hasattr(self, "cudagraph_batch_sizes")
and self.cudagraph_batch_sizes
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
):
# Use CUDA graphs.
# Add padding to the batch size.
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
# Eager mode.
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to a multiple of tensor_parallel_size when
# collective fusion for SP is enabled
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
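
A sketch of the rounding `_pad_for_sequence_parallelism` is expected to perform when SP collective fusion is enabled (the surrounding guard conditions are omitted, and the free function here is an illustration, not the actual method):

```python
def pad_for_sequence_parallelism(num_scheduled_tokens: int, tp_size: int) -> int:
    # Round up to the next multiple of tensor_parallel_size so the token
    # dimension can be split evenly across TP ranks.
    return (num_scheduled_tokens + tp_size - 1) // tp_size * tp_size

assert pad_for_sequence_parallelism(13, 4) == 16
assert pad_for_sequence_parallelism(16, 4) == 16
```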
@ -2738,6 +2694,87 @@ class GPUModelRunner(
**model_kwargs,
)
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
num_reqs: int,
num_scheduled_tokens_np: np.ndarray,
max_num_scheduled_tokens: int,
use_cascade_attn: bool,
allow_microbatching: bool = True,
force_eager: bool = False,
# For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
) -> tuple[
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
]:
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
uniform_decode = (
(
(max_num_scheduled_tokens == self.uniform_decode_query_len)
and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
)
if force_uniform_decode is None
else force_uniform_decode
)
has_lora = (
len(self.input_batch.lora_id_to_lora_request) > 0
if force_has_lora is None
else force_has_lora
)
dispatch_cudagraph = (
lambda num_tokens: self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens,
has_lora=has_lora,
use_cascade_attn=use_cascade_attn,
uniform_decode=uniform_decode,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
)
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
num_tokens_padded = batch_descriptor.num_tokens
# Extra coordination is needed when running data-parallel, since the
# padding must be agreed across ranks
ubatch_slices, num_tokens_across_dp = None, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
# in a P/D setup and still use CUDA graphs (enabled by this padding) on the
# decoder.
allow_dp_padding = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_padded,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
)
# Extract DP padding if there is any
if num_tokens_across_dp is not None:
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
# Assert that the agreed-upon token count is correct; otherwise
# num_tokens_across_dp will no longer be valid
assert batch_descriptor.num_tokens == num_tokens_padded
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
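
A tiny worked example of why the dispatch runs twice when data parallelism adds padding (assuming DP padding aligns every rank to the largest locally padded token count):

```python
# First pass: each DP rank pads its own token count to a captured graph size.
local_padded = {0: 8, 1: 64}   # e.g. rank 0 scheduled 7 tokens, rank 1 scheduled 60

# coordinate_batch_across_dp settles on a common size so all ranks replay
# graphs of the same shape (assumption for this sketch: it takes the maximum).
agreed = max(local_padded.values())
num_tokens_across_dp = [agreed] * len(local_padded)

# Second pass: rank 0 re-dispatches with 64 tokens; the resulting descriptor
# must report exactly 64, otherwise num_tokens_across_dp would be stale.
assert all(n == agreed for n in num_tokens_across_dp)
```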
@torch.inference_mode()
def execute_model(
self,
@ -2790,7 +2827,7 @@ class GPUModelRunner(
# returns True. Before returning early here we call
# dummy run to ensure coordinate_batch_across_dp
# is called, to avoid out-of-sync issues.
self._dummy_run(self._get_num_input_tokens(1))
self._dummy_run(1)
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
@ -2809,36 +2846,63 @@ class GPUModelRunner(
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
(
logits_indices,
spec_decode_metadata,
ubatch_slices,
num_tokens_across_dp,
) = self._prepare_inputs(
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
scheduler_output,
num_scheduled_tokens_np,
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and ubatch_slices is None:
if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
# Pre-compute cascade attention prefix lengths
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np,
self.input_batch.num_computed_tokens_cpu[:num_reqs],
scheduler_output.num_common_prefix_blocks,
)
# TODO(lucas): move cudagraph dispatching here:
# https://github.com/vllm-project/vllm/issues/23789
(
cudagraph_mode,
batch_desc,
ubatch_slices,
num_tokens_across_dp,
) = self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=max_num_scheduled_tokens,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
logger.debug(
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
"ubatch_slices: %s, num_tokens_across_dp: %s",
cudagraph_mode,
batch_desc,
ubatch_slices,
num_tokens_across_dp,
)
num_tokens_padded = batch_desc.num_tokens
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
attn_metadata, spec_decode_common_attn_metadata = (
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
(attn_metadata, spec_decode_common_attn_metadata) = (
self._build_attention_metadata(
total_num_scheduled_tokens=total_num_scheduled_tokens,
max_num_scheduled_tokens=max_num_scheduled_tokens,
num_tokens=num_tokens_unpadded,
num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs,
num_reqs_padded=num_reqs_padded if pad_attn else None,
max_query_len=max_num_scheduled_tokens,
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
@ -2847,49 +2911,22 @@ class GPUModelRunner(
)
)
dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices:
assert num_tokens_across_dp is not None
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
elif num_tokens_across_dp is not None:
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
else:
num_input_tokens = self._get_num_input_tokens(
scheduler_output.total_num_scheduled_tokens
)
(
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(
scheduler_output, num_input_tokens, intermediate_tensors
)
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
batch_desc = BatchDescriptor(
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
batch_desc,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
(
input_ids,
inputs_embeds,
positions,
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(
scheduler_output, num_tokens_padded, intermediate_tensors
)
# Set cudagraph mode to none if calc_kv_scales is true.
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.
if self.calculate_kv_scales:
cudagraph_runtime_mode = CUDAGraphMode.NONE
cudagraph_mode = CUDAGraphMode.NONE
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
@ -2899,10 +2936,10 @@ class GPUModelRunner(
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens=num_tokens_padded,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
@ -2952,7 +2989,7 @@ class GPUModelRunner(
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens
self.vllm_config, num_tokens_padded
)
}
get_pp_group().send_tensor_dict(
@ -3841,52 +3878,44 @@ class GPUModelRunner(
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
num_tokens_unpadded = int(num_scheduled_tokens.sum())
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
# Disable DP padding when running eager
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
# We currently only microbatch if the number of tokens is
# over a certain threshold.
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
num_tokens_unpadded=total_num_scheduled_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=total_num_scheduled_tokens,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens,
)
num_tokens_after_padding = num_tokens
if num_tokens_across_dp is not None:
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
num_scheduled_tokens_np=num_scheduled_tokens,
max_num_scheduled_tokens=max_query_len,
use_cascade_attn=False,
allow_microbatching=allow_microbatching,
force_eager=is_profile
or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
# `force_uniform_decode` is used for cudagraph capture: when capturing
# mixed prefill-decode batches, we sometimes use num_tokens == num_reqs,
# which looks like a uniform decode batch to the dispatcher, but we
# actually want to capture a piecewise cudagraph
force_uniform_decode=uniform_decode,
# `force_has_lora` is used for cudagraph capture: LoRA is activated later
# in the context manager, but we need to know the LoRA state when
# determining the batch descriptor for capture
force_has_lora=activate_lora,
)
if not is_profile
else (CUDAGraphMode.NONE, None)
)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode == _cg_mode
), (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
)
if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = _cudagraph_mode
else:
cudagraph_runtime_mode = _cg_mode
assert cudagraph_runtime_mode == _cudagraph_mode, (
f"Cudagraph runtime mode mismatch in dummy_run. "
f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
)
num_tokens_padded = batch_desc.num_tokens
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
attn_metadata: PerLayerAttnMetadata | None = None
@ -3909,9 +3938,9 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
attn_metadata, _ = self._build_attention_metadata(
total_num_scheduled_tokens=num_tokens,
max_num_scheduled_tokens=max_query_len,
num_reqs=num_reqs,
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,
ubatch_slices=ubatch_slices,
for_cudagraph_capture=True,
)
@ -3924,29 +3953,29 @@ class GPUModelRunner(
remove_lora,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
assert num_tokens_padded <= self.max_num_tokens
model_kwargs = self._init_model_kwargs(num_tokens_padded)
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = {
**model_kwargs,
**self._dummy_mm_kwargs(num_reqs),
}
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs(num_tokens_padded)
else:
input_ids = self.input_ids.gpu[:num_tokens_after_padding]
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
positions = self.mrope_positions.gpu[:, :num_tokens_padded]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else:
positions = self.positions.gpu[:num_tokens_after_padding]
positions = self.positions.gpu[:num_tokens_padded]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@ -3961,26 +3990,26 @@ class GPUModelRunner(
)
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens_after_padding, None, False
num_tokens_padded, None, False
)
if ubatch_slices is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
# the padding refactor.
num_tokens_after_padding = ubatch_slices[0].num_tokens
num_tokens_padded = ubatch_slices[0].num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_after_padding
num_tokens_across_dp[:] = num_tokens_padded
with (
self.maybe_randomize_inputs(input_ids),
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens_after_padding,
num_tokens=num_tokens_padded,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
),
):
@ -4706,8 +4735,7 @@ class GPUModelRunner(
# Trigger cudagraph dispatching key initialization after the
# cudagraph mode has been resolved.
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
self.compilation_config.cudagraph_mode = cudagraph_mode
self.cudagraph_dispatcher.initialize_cudagraph_keys(
cudagraph_mode, self.uniform_decode_query_len
)

View File

@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
@ -487,6 +488,7 @@ class Worker(WorkerBase):
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
@ -534,12 +536,39 @@ class Worker(WorkerBase):
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens
all_gather_tensors = {}
compilation_config = self.vllm_config.compilation_config
parallel_config = self.vllm_config.parallel_config
if (
parallel_config.pipeline_parallel_size > 1
and compilation_config.pass_config.enable_sequence_parallelism
and forward_pass
):
# currently only supported by V1 GPUModelRunner
assert isinstance(self.model_runner, GPUModelRunner)
num_scheduled_tokens_np = np.array(
list(scheduler_output.num_scheduled_tokens.values()),
dtype=np.int32,
)
}
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
use_cascade_attn=False, # TODO(lucas): Handle cascade attention
)
)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, batch_desc.num_tokens
)
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),