[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)

Lucas Wilkinson 2025-11-26 14:07:13 -05:00 committed by GitHub
parent 430dd4d9eb
commit 56539cddac
10 changed files with 401 additions and 283 deletions

View File

@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and current
 ```python
 class BatchDescriptor(NamedTuple):
     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int
+    uniform: bool = False
+    has_lora: bool = False
 ```
-where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
+where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests have the same query length. Many attention backends only support full CUDA Graphs when the batch is uniform; pure decode batches are uniform but may have a query length greater than 1 (i.e. `num_tokens == num_reqs` does not hold), as in the validation pass of speculative decoding where "decode" batches have a query length of `1 + num_spec_tokens`.

-The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
+The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item.

 !!! note
     The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
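
To make the new fields concrete, here is a small self-contained sketch of how a speculative-decode batch is described and how it relaxes to the mixed-batch key. The class below is a simplified local stand-in for the real `BatchDescriptor` (with `relax_for_mixed_batch_cudagraphs` re-implemented inline), and the sizes are illustrative only:

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    # Simplified stand-in mirroring the fields shown above.
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Drop the request count and uniformity so the key can match the more
        # general mixed prefill-decode (PIECEWISE) cudagraphs.
        return BatchDescriptor(self.num_tokens, None, False, self.has_lora)


# Pure-decode batch with 2 speculative tokens per request: every request has
# query length 1 + 2 = 3, so the batch is uniform even though num_tokens != num_reqs.
num_spec_tokens, num_reqs = 2, 8
desc = BatchDescriptor(num_tokens=num_reqs * (1 + num_spec_tokens), num_reqs=num_reqs, uniform=True)
print(desc)                                     # uniform spec-decode key
print(desc.relax_for_mixed_batch_cudagraphs())  # relaxed key used as the PIECEWISE fallback
```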

View File

@@ -42,12 +42,24 @@ def _create_vllm_config(
     mock_config.compilation_config = compilation_config
     mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
     mock_config.parallel_config = ParallelConfig()
+    mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:
         mock_config.lora_config = None
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1()
+    # mimic VllmConfig.__post_init__
+    if compilation_config.cudagraph_capture_sizes:
+        compilation_config.max_cudagraph_capture_size = (
+            compilation_config.cudagraph_capture_sizes[-1]
+        )
+    compilation_config.post_init_cudagraph_sizes()
+    mock_config.pad_for_cudagraph = (
+        lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
+    )
     return mock_config
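
As an aside, the `pad_for_cudagraph` hook the mock wires up above is just a lookup from a token count to the next captured size. A minimal sketch of that behaviour, assuming sorted capture sizes (the table construction here is a hypothetical simplification, not `VllmConfig`'s actual `post_init_cudagraph_sizes` logic):

```python
def build_bs_to_padded_graph_size(capture_sizes: list[int]) -> list[int]:
    # Map every batch size 0..max to the smallest capture size that covers it.
    max_size = capture_sizes[-1]  # capture sizes assumed sorted ascending
    table = [0] * (max_size + 1)
    it = iter(capture_sizes)
    current = next(it)
    for bs in range(max_size + 1):
        while bs > current:
            current = next(it)
        table[bs] = current
    return table


capture_sizes = [1, 2, 4, 8, 16]
bs_to_padded_graph_size = build_bs_to_padded_graph_size(capture_sizes)
pad_for_cudagraph = lambda batch_size: bs_to_padded_graph_size[batch_size]

assert pad_for_cudagraph(3) == 4
assert pad_for_cudagraph(9) == 16
assert pad_for_cudagraph(16) == 16
```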
@@ -109,9 +121,11 @@ class TestCudagraphDispatcher:
         # 1. non-uniform batch, size in cudagraph size list
         desc_full_exact = BatchDescriptor(
             num_tokens=8,
-            uniform_decode=False,
+            uniform=False,
+        )
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=False, has_lora=False
         )
-        rt_mode, key = dispatcher.dispatch(desc_full_exact)
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_full_exact
@@ -122,32 +136,37 @@ class TestCudagraphDispatcher:
             assert rt_mode == CUDAGraphMode.NONE

         # 2. uniform decode batch, size in cudagraph size list
-        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
-        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
+        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=True, has_lora=False
+        )
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
-            assert key == desc_uniform_exact.non_uniform
+            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
         elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact
         elif cudagraph_mode_str == "PIECEWISE":
             assert rt_mode == CUDAGraphMode.PIECEWISE
-            assert key == desc_uniform_exact.non_uniform
+            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
         else:
             assert rt_mode == CUDAGraphMode.NONE

         # 3. No key match
-        desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
-        rt_mode, key = dispatcher.dispatch(desc_no_match)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=15, uniform_decode=False, has_lora=False
+        )
         assert rt_mode == CUDAGraphMode.NONE
-        assert key is None
+        assert key == BatchDescriptor(num_tokens=15)

         # 4. Cascade attention should have a fall back mode
-        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
-        rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
+        rt_mode, key = dispatcher.dispatch(
+            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+        )
         if "PIECEWISE" in cudagraph_mode_str:  # string contains check
             assert rt_mode == CUDAGraphMode.PIECEWISE
-            assert key == desc_full_exact.non_uniform
+            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
         else:
             assert rt_mode == CUDAGraphMode.NONE

View File

@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
     """
     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int | None = None
     """
-    False can also be used for an uniform decode batch to dispatch to the
-    cudagraph supporting non-uniform batches.
+    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
+    the cudagraphs can handle any number of requests.
+    """
+    uniform: bool = False
+    """
+    True if all the requests in the batch have the same number of tokens.
     """
     has_lora: bool = False
     """
     Whether this batch has active LoRA adapters.
     """

-    @property
-    def non_uniform(self) -> "BatchDescriptor":
+    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
         """
-        Return a non-uniform version of current batch descriptor.
+        Return a relaxed version of current batch descriptor that is still compatible
+        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
         """
         return BatchDescriptor(
-            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
         )

View File

@@ -930,31 +930,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if num_decodes > 0:
             pure_decode = num_prefills == 0
-            # possible required padding for cudagraph replay
             use_cudagraph = (
                 self.enable_cuda_graph
                 and pure_decode
                 and num_decode_tokens <= self._decode_cudagraph_max_bs
             )
-            if use_cudagraph:
-                num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                    num_decode_tokens
-                )
-                # Carefully fulfill the padding region with reasonable value
-                # on cpu.
-                # Make sure paged_kv_indptr_cpu is not decreasing
-                self.paged_kv_indptr_cpu[
-                    1 + num_decodes : 1 + num_input_tokens
-                ].fill_(paged_kv_indptr_cpu[-1])
-                # Fill the remaining paged_kv_last_page_len_cpu with 1.
-                # This is because flashinfer treats 0 as a full page
-                # instead of empty.
-                self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
-                    1
-                )
-            else:
-                num_input_tokens = num_decode_tokens
+            num_input_tokens = num_decode_tokens

             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph

View File

@@ -107,6 +107,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         )
         # -1 in case it's non-computed and causes later issues with indexing
         block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+        # -1 in the case we have a padded request (0 seq-len)
+        block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)

         return (
             block_idx_last_computed_token,

View File

@@ -72,6 +72,7 @@ class CommonAttentionMetadata:
     num_reqs: int
     """Number of requests"""
+    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
     num_actual_tokens: int
     """Total number of tokens in batch"""
     max_query_len: int

@@ -857,7 +858,9 @@ def split_decodes_and_prefills(
     if require_uniform:
         is_prefill = query_lens != query_lens[0]
     else:
-        is_prefill = query_lens > decode_threshold
+        # 0-query len indicates a padded request; leave this at the back
+        # of the batch with the prefills
+        is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
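
A quick illustration of the new `is_prefill` mask: zero-length (padded) requests are deliberately grouped with the prefills at the back of the batch. The index arithmetic below is a simplified stand-in for the rest of `split_decodes_and_prefills`, with made-up query lengths:

```python
import torch

# Three real decodes, one prefill, and two zero-length padded requests.
query_lens = torch.tensor([1, 1, 1, 7, 0, 0])
decode_threshold = 1

# Mirrors the non-uniform branch above.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

first_prefill = int(is_prefill.int().argmax().item())
num_decodes = first_prefill
num_prefills = len(query_lens) - first_prefill
print(num_decodes, num_prefills)  # 3 3
```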

View File

@@ -4,6 +4,9 @@ from itertools import product
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)


 class CudagraphDispatcher:
@@ -28,7 +31,11 @@ class CudagraphDispatcher:
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.uniform_decode_query_len = (
+            1
+            if not self.vllm_config.speculative_config
+            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
+        )

         # Dict to store valid cudagraph dispatching keys.
         self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
@@ -36,25 +43,42 @@ class CudagraphDispatcher:
             CUDAGraphMode.FULL: set(),
         }

-        not_use_piecewise_compilation = (
-            not self.cudagraph_mode.requires_piecewise_compilation()
-        )
         assert (
-            not_use_piecewise_compilation
+            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
             or self.compilation_config.is_attention_compiled_piecewise()
         ), (
             "Compilation mode should be CompilationMode.VLLM_COMPILE when "
             "cudagraph_mode piecewise cudagraphs is used, "
             "and attention should be in splitting_ops or "
             "inductor splitting should be used. "
-            f"cudagraph_mode={self.cudagraph_mode}, "
+            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
             f"compilation_mode={self.compilation_config.mode}, "
             f"splitting_ops={self.compilation_config.splitting_ops}"
         )
         self.keys_initialized = False

+    def _create_padded_batch_descriptor(
+        self, num_tokens: int, uniform_decode: bool, has_lora: bool
+    ) -> BatchDescriptor:
+        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+        uniform_decode_query_len = self.uniform_decode_query_len
+        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
+            num_reqs = num_tokens_padded // uniform_decode_query_len
+            assert num_tokens_padded % uniform_decode_query_len == 0
+        else:
+            uniform_decode = False
+            num_reqs = min(num_tokens_padded, max_num_seqs)
+        return BatchDescriptor(
+            num_tokens=num_tokens_padded,
+            num_reqs=num_reqs,
+            uniform=uniform_decode,
+            has_lora=has_lora,
+        )
+
     def add_cudagraph_key(
         self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
     ):
@@ -66,7 +90,9 @@ class CudagraphDispatcher:
     def initialize_cudagraph_keys(
         self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
     ):
-        # This should be called only after attention backend is initialized.
+        # This should be called only after attention backend is initialized. So we can
+        # get the correct cudagraph mode after backend support is resolved.
+        self.cudagraph_mode = cudagraph_mode

         # LoRA activation cases to specialize the cuda graphs on
         if self.vllm_config.lora_config:
@@ -86,9 +112,9 @@ class CudagraphDispatcher:
         ):
             self.add_cudagraph_key(
                 cudagraph_mode.mixed_mode(),
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=False, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(
+                    bs, False, has_lora
+                ).relax_for_mixed_batch_cudagraphs(),
             )

         # if decode cudagraph mode is FULL, and we don't already have mixed
@@ -109,40 +135,49 @@ class CudagraphDispatcher:
         for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
             self.add_cudagraph_key(
                 CUDAGraphMode.FULL,
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=True, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(bs, True, has_lora),
             )
         self.keys_initialized = True

     def dispatch(
-        self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
-    ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
+        self,
+        num_tokens: int,
+        uniform_decode: bool,
+        has_lora: bool,
+        use_cascade_attn: bool = False,
+    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
         Given conditions(e.g.,batch descriptor and if using cascade attention),
         dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
         """
-        # if not initialized, just skip dispatching.
-        if not self.keys_initialized:
-            return CUDAGraphMode.NONE, None
+        if (
+            not self.keys_initialized
+            or self.cudagraph_mode == CUDAGraphMode.NONE
+            or num_tokens > self.compilation_config.max_cudagraph_capture_size
+        ):
+            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

-        non_uniform_key = batch_descriptor.non_uniform
-        # if a batch use cascade attention, bypass checking full cudagraphs
+        batch_desc = self._create_padded_batch_descriptor(
+            num_tokens, uniform_decode, has_lora
+        )
+        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
+
         if not use_cascade_attn:
             # check if key exists for full cudagraph
-            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, batch_descriptor
-            # otherwise, check if non-uniform key exists
-            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, non_uniform_key
-        # also check if non-uniform key exists for more "general"
+            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, batch_desc
+            # otherwise, check if the relaxed key exists
+            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, relaxed_batch_desc
+        # also check if the relaxed key exists for more "general"
         # piecewise cudagraph
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
-            return CUDAGraphMode.PIECEWISE, non_uniform_key
-        # finally, just return no cudagraphs
-        return CUDAGraphMode.NONE, None
+        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
+            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
+        # finally, just return no cudagraphs and a trivial batch descriptor
+        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
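
The fallback chain above (exact FULL key, relaxed key against FULL, relaxed key against PIECEWISE, else NONE) can be modelled in isolation. The snippet below is a self-contained approximation that ignores padding, LoRA and cascade attention; `Key` and `Mode` are local stand-ins, not vLLM types:

```python
from enum import Enum
from typing import NamedTuple


class Key(NamedTuple):
    # Stand-in for BatchDescriptor: only the fields the dispatch chain cares about.
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False


class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dispatch(key: Key, full_keys: set[Key], piecewise_keys: set[Key]):
    # Mirrors the fallback chain: exact FULL key, then the relaxed (mixed-batch)
    # key against FULL, then PIECEWISE, else NONE with a trivial key.
    relaxed = key._replace(num_reqs=None, uniform=False)
    if key in full_keys:
        return Mode.FULL, key
    if relaxed in full_keys:
        return Mode.FULL, relaxed
    if relaxed in piecewise_keys:
        return Mode.PIECEWISE, relaxed
    return Mode.NONE, Key(key.num_tokens)


full = {Key(num_tokens=8, num_reqs=8, uniform=True)}
piecewise = {Key(num_tokens=8)}
print(dispatch(Key(8, 8, True), full, piecewise))   # FULL, exact uniform key
print(dispatch(Key(8, 4, False), full, piecewise))  # PIECEWISE, relaxed key
```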

View File

@@ -9,6 +9,7 @@ from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
+    UBatchSlice,
     UBatchSlices,
     check_ubatch_thresholds,
     create_ubatch_slices,
@@ -88,6 +89,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
     return num_tokens_across_dp.cpu()

+
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
+    padded_second_ubatch_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    ubatch_slices[1] = UBatchSlice(
+        padded_second_ubatch_slice, padded_second_ubatch_slice
+    )
+
+
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -220,11 +232,14 @@ def coordinate_batch_across_dp(
         # to the second ubatch in pad_out_ubatch_slice after attention
         # metadata creation
         assert num_tokens_after_padding is not None
-        token_split_point = int(num_tokens_after_padding[0].item()) // 2
+        num_tokens_padded = int(num_tokens_after_padding[0].item())
+        token_split_point = int(num_tokens_padded) // 2

         assert num_scheduled_tokens_per_request is not None
         ubatch_slices = create_ubatch_slices(
             num_scheduled_tokens_per_request, token_split_point
         )
+        _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
+        assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded

     return (ubatch_slices, num_tokens_after_padding)
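
For intuition, here is a toy version of the ubatch padding step: the rank-agreed padded token count is split in half, and the second micro-batch slice is stretched so the two slices together cover the padding. Plain `slice` objects stand in for vLLM's `UBatchSlice`, and the numbers are made up:

```python
num_tokens_unpadded = 80
num_tokens_padded = 96

# Split at half of the *padded* total, as coordinate_batch_across_dp does.
token_split_point = num_tokens_padded // 2
ubatch_slices = [
    slice(0, token_split_point),
    slice(token_split_point, num_tokens_unpadded),
]

# Equivalent of _pad_out_ubatch_slice: extend the second slice to the padded total.
ubatch_slices[1] = slice(ubatch_slices[1].start, num_tokens_padded)

assert sum(s.stop - s.start for s in ubatch_slices) == num_tokens_padded
print(ubatch_slices)  # [slice(0, 48, None), slice(48, 96, None)]
```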

View File

@@ -151,7 +151,6 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
     UBatchSlices,
     check_ubatch_thresholds,
 )
@@ -1239,17 +1238,13 @@ class GPUModelRunner(
         self,
         scheduler_output: "SchedulerOutput",
         num_scheduled_tokens: np.ndarray,
-        max_num_scheduled_tokens: int,
     ) -> tuple[
         torch.Tensor,
         SpecDecodeMetadata | None,
-        UBatchSlices | None,
-        torch.Tensor | None,
     ]:
         """
         :return: tuple[
             logits_indices, spec_decode_metadata,
-            ubatch_slices, num_tokens_across_dp,
         ]
         """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1364,28 +1359,6 @@ class GPUModelRunner(
             self.query_start_loc.copy_to_gpu()
             query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]

-        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
-        num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
-        uniform_decode = (
-            max_num_scheduled_tokens == self.uniform_decode_query_len
-        ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
-
-        # Disable DP padding when running eager to avoid excessive padding when
-        # running prefills. This lets us set enforce_eager on the prefiller in
-        # a P/D setup and still use CUDA graphs (enabled by this padding) on the
-        # decoder.
-        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-
-        ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_tokens_unpadded=num_tokens_unpadded,
-            parallel_config=self.parallel_config,
-            allow_microbatching=True,
-            allow_dp_padding=allow_dp_padding,
-            num_tokens_padded=num_tokens_padded,
-            uniform_decode=uniform_decode,
-            num_scheduled_tokens_per_request=num_scheduled_tokens,
-        )
-
         self.seq_lens.np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
         )
@@ -1486,15 +1459,15 @@ class GPUModelRunner(
         return (
             logits_indices,
             spec_decode_metadata,
-            ubatch_slices,
-            num_tokens_across_dp,
         )

     def _build_attention_metadata(
         self,
-        total_num_scheduled_tokens: int,
-        max_num_scheduled_tokens: int,
+        num_tokens: int,
         num_reqs: int,
+        max_query_len: int,
+        num_tokens_padded: int | None = None,
+        num_reqs_padded: int | None = None,
         ubatch_slices: UBatchSlices | None = None,
         logits_indices: torch.Tensor | None = None,
         use_spec_decode: bool = False,
@@ -1505,6 +1478,9 @@ class GPUModelRunner(
         """
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
+        num_tokens_padded = num_tokens_padded or num_tokens
+        num_reqs_padded = num_reqs_padded or num_reqs
+
         logits_indices_padded = None
         num_logits_indices = None
         if logits_indices is not None:
@@ -1522,28 +1498,13 @@ class GPUModelRunner(
                 self.dcp_rank,
                 self.parallel_config.cp_kv_cache_interleave_size,
             )
-            self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
+            self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
+            self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)

         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
             attn_metadata = [dict() for _ in range(len(ubatch_slices))]

-        # Used in the below loop
-        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
-        query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
-        seq_lens = self.seq_lens.gpu[:num_reqs]
-        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
-        num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
-            :num_reqs
-        ]
-        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
-        if self.dcp_world_size > 1:
-            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
-            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
-
-        spec_decode_common_attn_metadata = None
         if for_cudagraph_capture:
             # For some attention backends (e.g. FA) with sliding window models we need
             # to make sure the backend see a max_seq_len that is larger to the sliding
@@ -1559,6 +1520,22 @@ class GPUModelRunner(
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()

+        # Used in the below loop, uses padded shapes
+        query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
+        query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
+        seq_lens = self.seq_lens.gpu[:num_reqs_padded]
+        seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
+        num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
+            :num_reqs_padded
+        ]
+        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+        if self.dcp_world_size > 1:
+            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
+            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
+
+        spec_decode_common_attn_metadata = None
+
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_gid, kv_cache_group in enumerate(
@@ -1567,30 +1544,31 @@ class GPUModelRunner(
             encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
                 num_scheduled_tokens or {},
                 kv_cache_group.kv_cache_spec,
-                num_reqs,
+                num_reqs_padded,
             )
             if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
                 # create a dummy block table and slot mapping for them.
                 blk_table_tensor = torch.zeros(
-                    (num_reqs, 1),
+                    (num_tokens_padded, 1),
                     dtype=torch.int32,
                     device=self.device,
                 )
                 slot_mapping = torch.zeros(
-                    (total_num_scheduled_tokens,),
+                    (num_tokens_padded,),
                     dtype=torch.int64,
                     device=self.device,
                 )
             else:
                 blk_table = self.input_batch.block_table[kv_cache_gid]
-                blk_table_tensor = blk_table.get_device_tensor(num_reqs)
-                slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
+                blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
+                slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
                 # Fill unused with -1. Needed for reshape_and_cache in full cuda
-                # graph mode.
-                blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
+                # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
+                slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+                blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)

             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
@@ -1598,9 +1576,9 @@ class GPUModelRunner(
                 seq_lens=seq_lens,
                 seq_lens_cpu=seq_lens_cpu,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
+                num_actual_tokens=num_tokens_padded,
+                num_reqs=num_reqs_padded,
+                max_query_len=max_query_len,
                 max_seq_len=max_seq_len,
                 block_table_tensor=blk_table_tensor,
                 slot_mapping=slot_mapping,
@@ -1631,9 +1609,11 @@ class GPUModelRunner(
             extra_attn_metadata_args = {}
             if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
                 extra_attn_metadata_args = dict(
-                    num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
+                    num_accepted_tokens=self.num_accepted_tokens.gpu[
+                        :num_reqs_padded
+                    ],
                     num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
-                        :num_reqs
+                        :num_reqs_padded
                     ],
                 )
@@ -1677,6 +1657,7 @@ class GPUModelRunner(
     def _compute_cascade_attn_prefix_lens(
         self,
         num_scheduled_tokens: np.ndarray,
+        num_computed_tokens: np.ndarray,
         num_common_prefix_blocks: list[int],
     ) -> list[list[int]] | None:
         """
@@ -1699,6 +1680,7 @@ class GPUModelRunner(
             # 0 if cascade attention should not be used
             cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
                 num_scheduled_tokens,
+                num_computed_tokens,
                 num_common_prefix_blocks[kv_cache_gid],
                 attn_group.kv_cache_spec,
                 attn_group.get_metadata_builder(),
@@ -1711,6 +1693,7 @@ class GPUModelRunner(
     def _compute_cascade_attn_prefix_len(
         self,
         num_scheduled_tokens: np.ndarray,
+        num_computed_tokens: np.ndarray,
         num_common_prefix_blocks: int,
         kv_cache_spec: KVCacheSpec,
         attn_metadata_builder: AttentionMetadataBuilder,
@@ -1777,10 +1760,7 @@ class GPUModelRunner(
         # and the second kernel will get an empty input. While this is not
         # a fundamental problem, our current implementation does not support
         # this case.
-        num_reqs = len(num_scheduled_tokens)
-        common_prefix_len = min(
-            common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
-        )
+        common_prefix_len = min(common_prefix_len, num_computed_tokens.min())

         # common_prefix_len should be a multiple of the block size.
         common_prefix_len = (
             common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
@@ -2334,19 +2314,6 @@ class GPUModelRunner(
             log_stats=self.parallel_config.eplb_config.log_balancedness,
         )

-    # This is where the second ubatch is adjusted to account for the padding.
-    # Should be called after attention metadata creation. This just pads
-    # the second ubatch slice out to the total number of tokens
-    # (num_tokens + padding)
-    @staticmethod
-    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
-        padded_second_ubatch_slice = slice(
-            ubatch_slices[1].token_slice.start, num_total_tokens
-        )
-        ubatch_slices[1] = UBatchSlice(
-            padded_second_ubatch_slice, padded_second_ubatch_slice
-        )
-
     def _pool(
         self,
         hidden_states: torch.Tensor,
@@ -2391,18 +2358,7 @@ class GPUModelRunner(
             pooler_output=pooler_output,
         )

-    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
-        if (
-            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-            and hasattr(self, "cudagraph_batch_sizes")
-            and self.cudagraph_batch_sizes
-            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
-        ):
-            # Use CUDA graphs.
-            # Add padding to the batch size.
-            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
-        # Eager mode.
+    def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
@@ -2738,6 +2694,87 @@ class GPUModelRunner(
             **model_kwargs,
         )

+    def _determine_batch_execution_and_padding(
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        num_scheduled_tokens_np: np.ndarray,
+        max_num_scheduled_tokens: int,
+        use_cascade_attn: bool,
+        allow_microbatching: bool = True,
+        force_eager: bool = False,
+        # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will
+        # be improved in model runner v2)
+        force_uniform_decode: bool | None = None,
+        force_has_lora: bool | None = None,
+    ) -> tuple[
+        CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
+    ]:
+        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
+        uniform_decode = (
+            (
+                (max_num_scheduled_tokens == self.uniform_decode_query_len)
+                and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
+            )
+            if force_uniform_decode is None
+            else force_uniform_decode
+        )
+        has_lora = (
+            len(self.input_batch.lora_id_to_lora_request) > 0
+            if force_has_lora is None
+            else force_has_lora
+        )
+
+        dispatch_cudagraph = (
+            lambda num_tokens: self.cudagraph_dispatcher.dispatch(
+                num_tokens=num_tokens,
+                has_lora=has_lora,
+                use_cascade_attn=use_cascade_attn,
+                uniform_decode=uniform_decode,
+            )
+            if not force_eager
+            else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
+        )
+
+        cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+        num_tokens_padded = batch_descriptor.num_tokens
+
+        # Extra coordination when running data-parallel since we need to coordinate
+        # across ranks
+        ubatch_slices, num_tokens_across_dp = None, None
+        if self.vllm_config.parallel_config.data_parallel_size > 1:
+            # Disable DP padding when running eager to avoid excessive padding when
+            # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
+            # in a P/D setup and still use CUDA graphs (enabled by this padding) on the
+            # decoder.
+            allow_dp_padding = (
+                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            )
+
+            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens_padded,
+                parallel_config=self.parallel_config,
+                allow_microbatching=allow_microbatching,
+                allow_dp_padding=allow_dp_padding,
+                num_tokens_padded=num_tokens_padded,
+                uniform_decode=uniform_decode,
+                num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+            )
+
+            # Extract DP padding if there is any
+            if num_tokens_across_dp is not None:
+                dp_rank = self.parallel_config.data_parallel_rank
+                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
+
+                # Re-dispatch with DP padding
+                cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+                # Assert to make sure the agreed upon token count is correct otherwise
+                # num_tokens_across_dp will no-longer be valid
+                assert batch_descriptor.num_tokens == num_tokens_padded
+
+        return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
+
     @torch.inference_mode()
     def execute_model(
         self,
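
The ordering that `_determine_batch_execution_and_padding` enforces (pad for sequence parallelism, dispatch to a cudagraph capture size, take the DP-agreed padding, then re-dispatch so the final count is still a capture size) can be sketched with stand-in helpers. Everything below — capture sizes, TP multiple, helper names — is illustrative, not vLLM's implementation:

```python
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64]  # assumed sorted ascending
TP_SIZE = 2  # assumed sequence-parallel padding multiple


def pad_for_sequence_parallelism(num_tokens: int) -> int:
    # Round up to a multiple of the tensor-parallel size.
    return -(-num_tokens // TP_SIZE) * TP_SIZE


def pad_for_cudagraph(num_tokens: int) -> int:
    # Round up to the next captured size (or leave as-is past the largest one).
    return next((s for s in CAPTURE_SIZES if s >= num_tokens), num_tokens)


def determine_padding(num_tokens: int, num_tokens_across_dp: list[int]) -> int:
    # 1) pad for SP, 2) pad to a cudagraph size ("dispatch"),
    # 3) take the DP-wide maximum, 4) re-dispatch so the result is a capture size.
    padded = pad_for_cudagraph(pad_for_sequence_parallelism(num_tokens))
    dp_padded = max(
        [pad_for_cudagraph(pad_for_sequence_parallelism(t)) for t in num_tokens_across_dp]
        + [padded]
    )
    return pad_for_cudagraph(dp_padded)


print(determine_padding(5, num_tokens_across_dp=[5, 13]))  # 16
```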
@@ -2790,7 +2827,7 @@ class GPUModelRunner(
                 # returns True. before returning early here we call
                 # dummy run to ensure coordinate_batch_across_dp
                 # is called into to avoid out of sync issues.
-                self._dummy_run(self._get_num_input_tokens(1))
+                self._dummy_run(1)
                 if not has_kv_transfer_group():
                     # Return empty ModelRunnerOutput if no work to do.
                     return EMPTY_MODEL_RUNNER_OUTPUT
@@ -2809,36 +2846,63 @@ class GPUModelRunner(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
         max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens

         (
             logits_indices,
             spec_decode_metadata,
-            ubatch_slices,
-            num_tokens_across_dp,
         ) = self._prepare_inputs(
-            scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
+            scheduler_output,
+            num_scheduled_tokens_np,
         )

         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
-        if self.cascade_attn_enabled and ubatch_slices is None:
+        if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
             # Pre-compute cascade attention prefix lengths
+            # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
             cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                 num_scheduled_tokens_np,
+                self.input_batch.num_computed_tokens_cpu[:num_reqs],
                 scheduler_output.num_common_prefix_blocks,
             )

-        # TODO(lucas): move cudagraph dispatching here:
-        # https://github.com/vllm-project/vllm/issues/23789
+        (
+            cudagraph_mode,
+            batch_desc,
+            ubatch_slices,
+            num_tokens_across_dp,
+        ) = self._determine_batch_execution_and_padding(
+            num_tokens=num_tokens_unpadded,
+            num_reqs=num_reqs,
+            num_scheduled_tokens_np=num_scheduled_tokens_np,
+            max_num_scheduled_tokens=max_num_scheduled_tokens,
+            use_cascade_attn=cascade_attn_prefix_lens is not None,
+        )
+
+        logger.debug(
+            "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
+            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            cudagraph_mode,
+            batch_desc,
+            ubatch_slices,
+            num_tokens_across_dp,
+        )
+
+        num_tokens_padded = batch_desc.num_tokens
+        num_reqs_padded = (
+            batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+        )

-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
-        attn_metadata, spec_decode_common_attn_metadata = (
+        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+        (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
-                total_num_scheduled_tokens=total_num_scheduled_tokens,
-                max_num_scheduled_tokens=max_num_scheduled_tokens,
+                num_tokens=num_tokens_unpadded,
+                num_tokens_padded=num_tokens_padded if pad_attn else None,
                 num_reqs=num_reqs,
+                num_reqs_padded=num_reqs_padded if pad_attn else None,
+                max_query_len=max_num_scheduled_tokens,
                 ubatch_slices=ubatch_slices,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
@@ -2847,49 +2911,22 @@ class GPUModelRunner(
             )
         )

-        dp_rank = self.parallel_config.data_parallel_rank
-        if ubatch_slices:
-            assert num_tokens_across_dp is not None
-            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-            self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
-        elif num_tokens_across_dp is not None:
-            num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-        else:
-            num_input_tokens = self._get_num_input_tokens(
-                scheduler_output.total_num_scheduled_tokens
-            )
-
-        (
-            input_ids,
-            inputs_embeds,
-            positions,
-            intermediate_tensors,
-            model_kwargs,
-            ec_connector_output,
-        ) = self._preprocess(
-            scheduler_output, num_input_tokens, intermediate_tensors
-        )
-
-        uniform_decode = (
-            max_num_scheduled_tokens == self.uniform_decode_query_len
-        ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
-        batch_desc = BatchDescriptor(
-            num_tokens=num_input_tokens,
-            uniform_decode=uniform_decode,
-            has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
-        )
-        cudagraph_runtime_mode, batch_descriptor = (
-            self.cudagraph_dispatcher.dispatch(
-                batch_desc,
-                use_cascade_attn=cascade_attn_prefix_lens is not None,
-            )
-        )
+        (
+            input_ids,
+            inputs_embeds,
+            positions,
+            intermediate_tensors,
+            model_kwargs,
+            ec_connector_output,
+        ) = self._preprocess(
+            scheduler_output, num_tokens_padded, intermediate_tensors
+        )

         # Set cudagraph mode to none if calc_kv_scales is true.
         # KV scales calculation involves dynamic operations that are incompatible
         # with CUDA graph capture.
         if self.calculate_kv_scales:
-            cudagraph_runtime_mode = CUDAGraphMode.NONE
+            cudagraph_mode = CUDAGraphMode.NONE
             # Mark KV scales as calculated after the first forward pass
             self.calculate_kv_scales = False
@@ -2899,10 +2936,10 @@ class GPUModelRunner(
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=num_input_tokens,
+                num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                batch_descriptor=batch_descriptor,
+                cudagraph_runtime_mode=cudagraph_mode,
+                batch_descriptor=batch_desc,
                 ubatch_slices=ubatch_slices,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
@@ -2952,7 +2989,7 @@ class GPUModelRunner(
             if not get_pp_group().is_last_rank:
                 all_gather_tensors = {
                     "residual": not is_residual_scattered_for_sp(
-                        self.vllm_config, num_input_tokens
+                        self.vllm_config, num_tokens_padded
                     )
                 }
                 get_pp_group().send_tensor_dict(
@@ -3841,52 +3878,44 @@ class GPUModelRunner(
         assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
-        total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
+        num_tokens_unpadded = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

-        # Disable DP padding when running eager
-        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-
-        # We currently only microbatch if the number of tokens is
-        # over a certain threshold.
-        ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_tokens_unpadded=total_num_scheduled_tokens,
-            parallel_config=self.vllm_config.parallel_config,
-            allow_microbatching=allow_microbatching,
-            allow_dp_padding=allow_dp_padding,
-            num_tokens_padded=total_num_scheduled_tokens,
-            uniform_decode=uniform_decode,
-            num_scheduled_tokens_per_request=num_scheduled_tokens,
-        )
-
-        num_tokens_after_padding = num_tokens
-        if num_tokens_across_dp is not None:
-            dp_rank = self.parallel_config.data_parallel_rank
-            num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
-
-        # filter out the valid batch descriptor
-        _cg_mode, batch_descriptor = (
-            self.cudagraph_dispatcher.dispatch(
-                BatchDescriptor(
-                    num_tokens=num_tokens_after_padding,
-                    uniform_decode=uniform_decode,
-                    has_lora=activate_lora and self.lora_config is not None,
-                )
-            )
-            if not is_profile
-            else (CUDAGraphMode.NONE, None)
-        )
-        if cudagraph_runtime_mode is not None:
-            # we allow forcing NONE when the dispatcher disagrees to support
-            # warm ups for cudagraph capture
-            assert (
-                cudagraph_runtime_mode == CUDAGraphMode.NONE
-                or cudagraph_runtime_mode == _cg_mode
-            ), (
-                f"Cudagraph runtime mode mismatch at dummy_run. "
-                f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
-            )
+        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
+            self._determine_batch_execution_and_padding(
+                num_tokens=num_tokens_unpadded,
+                num_reqs=num_reqs,
+                num_scheduled_tokens_np=num_scheduled_tokens,
+                max_num_scheduled_tokens=max_query_len,
+                use_cascade_attn=False,
+                allow_microbatching=allow_microbatching,
+                force_eager=is_profile
+                or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+                # `force_uniform_decode` is used for cudagraph capture; because for
+                # capturing mixed prefill-decode batches, we sometimes use
+                # num_tokens == num_reqs which looks like a uniform decode batch to the
+                # dispatcher; but we actually want to capture a piecewise cudagraph
+                force_uniform_decode=uniform_decode,
+                # `force_has_lora` is used for cudagraph capture; because LoRA is
+                # activated later in the context manager, but we need to know the
+                # LoRA state when determining the batch descriptor for capture
+                force_has_lora=activate_lora,
+            )
+        )
+
+        if cudagraph_runtime_mode is None:
+            cudagraph_runtime_mode = _cudagraph_mode
         else:
-            cudagraph_runtime_mode = _cg_mode
+            assert cudagraph_runtime_mode == _cudagraph_mode, (
+                f"Cudagraph runtime mode mismatch in dummy_run. "
+                f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
+            )
+
+        num_tokens_padded = batch_desc.num_tokens
+        num_reqs_padded = (
+            batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+        )

         attn_metadata: PerLayerAttnMetadata | None = None
@@ -3909,9 +3938,9 @@ class GPUModelRunner(
             self.query_start_loc.copy_to_gpu()

             attn_metadata, _ = self._build_attention_metadata(
-                total_num_scheduled_tokens=num_tokens,
-                max_num_scheduled_tokens=max_query_len,
-                num_reqs=num_reqs,
+                num_tokens=num_tokens_unpadded,
+                num_reqs=num_reqs_padded,
+                max_query_len=max_query_len,
                 ubatch_slices=ubatch_slices,
                 for_cudagraph_capture=True,
             )
@@ -3924,29 +3953,29 @@ class GPUModelRunner(
             remove_lora,
         ):
             # Make sure padding doesn't exceed max_num_tokens
-            assert num_tokens_after_padding <= self.max_num_tokens
-            model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
+            assert num_tokens_padded <= self.max_num_tokens
+            model_kwargs = self._init_model_kwargs(num_tokens_padded)
             if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
                 model_kwargs = {
                     **model_kwargs,
                     **self._dummy_mm_kwargs(num_reqs),
                 }
             elif self.enable_prompt_embeds:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
-                model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
+                model_kwargs = self._init_model_kwargs(num_tokens_padded)
             else:
-                input_ids = self.input_ids.gpu[:num_tokens_after_padding]
+                input_ids = self.input_ids.gpu[:num_tokens_padded]
                 inputs_embeds = None

             if self.uses_mrope:
-                positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
+                positions = self.mrope_positions.gpu[:, :num_tokens_padded]
             elif self.uses_xdrope_dim > 0:
-                positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
+                positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
             else:
-                positions = self.positions.gpu[:num_tokens_after_padding]
+                positions = self.positions.gpu[:num_tokens_padded]

             if get_pp_group().is_first_rank:
                 intermediate_tensors = None
@@ -3961,26 +3990,26 @@ class GPUModelRunner(
                 )
                 intermediate_tensors = self.sync_and_slice_intermediate_tensors(
-                    num_tokens_after_padding, None, False
+                    num_tokens_padded, None, False
                 )

             if ubatch_slices is not None:
                 # Adjust values to reflect a single ubatch.
                 # TODO(sage,lucas): this is cruft that should be addressed in
                 # the padding refactor.
-                num_tokens_after_padding = ubatch_slices[0].num_tokens
+                num_tokens_padded = ubatch_slices[0].num_tokens
                 if num_tokens_across_dp is not None:
-                    num_tokens_across_dp[:] = num_tokens_after_padding
+                    num_tokens_across_dp[:] = num_tokens_padded

             with (
                 self.maybe_randomize_inputs(input_ids),
                 set_forward_context(
                     attn_metadata,
                     self.vllm_config,
-                    num_tokens=num_tokens_after_padding,
+                    num_tokens=num_tokens_padded,
                     num_tokens_across_dp=num_tokens_across_dp,
                     cudagraph_runtime_mode=cudagraph_runtime_mode,
-                    batch_descriptor=batch_descriptor,
+                    batch_descriptor=batch_desc,
                     ubatch_slices=ubatch_slices,
                 ),
             ):
@@ -4706,8 +4735,7 @@ class GPUModelRunner(
         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
-        cudagraph_mode = self.compilation_config.cudagraph_mode
-        assert cudagraph_mode is not None
+        self.compilation_config.cudagraph_mode = cudagraph_mode
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
             cudagraph_mode, self.uniform_decode_query_len
         )

View File

@@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
 from types import NoneType
 from typing import TYPE_CHECKING, Any, cast

+import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn

 import vllm.envs as envs
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import (
     ensure_model_parallel_initialized,
     init_distributed_environment,
@@ -487,6 +488,7 @@ class Worker(WorkerBase):
         hidden_states, last_hidden_states = self.model_runner._dummy_run(
             num_tokens=max_num_reqs,
             skip_eplb=True,
+            cudagraph_runtime_mode=CUDAGraphMode.NONE,
         )
         if self.model_runner.is_pooling_model:
             self.model_runner._dummy_pooler_run(hidden_states)
@@ -534,12 +536,39 @@ class Worker(WorkerBase):
             intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
-        all_gather_tensors = {
-            "residual": not is_residual_scattered_for_sp(
-                self.vllm_config, num_input_tokens
-            )
-        }
+        all_gather_tensors = {}
+        compilation_config = self.vllm_config.compilation_config
+        parallel_config = self.vllm_config.parallel_config
+
+        if (
+            parallel_config.pipeline_parallel_size > 1
+            and compilation_config.pass_config.enable_sequence_parallelism
+            and forward_pass
+        ):
+            # currently only supported by V1 GPUModelRunner
+            assert isinstance(self.model_runner, GPUModelRunner)
+            num_scheduled_tokens_np = np.array(
+                list(scheduler_output.num_scheduled_tokens.values()),
+                dtype=np.int32,
+            )
+            # TODO(lucas): This is pretty gross; ideally we should only ever call
+            # `_determine_batch_execution_and_padding` once (will get called again
+            # in `execute_model`) but this requires a larger refactor of PP.
+            _, batch_desc, _, _ = (
+                self.model_runner._determine_batch_execution_and_padding(
+                    num_tokens=num_scheduled_tokens,
+                    num_reqs=len(num_scheduled_tokens_np),
+                    num_scheduled_tokens_np=num_scheduled_tokens_np,
+                    max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
+                    use_cascade_attn=False,  # TODO(lucas): Handle cascade attention
+                )
+            )
+            all_gather_tensors = {
+                "residual": not is_residual_scattered_for_sp(
+                    self.vllm_config, batch_desc.num_tokens
+                )
+            }

         if forward_pass and not get_pp_group().is_first_rank:
             tensor_dict = get_pp_group().recv_tensor_dict(
                 all_gather_group=get_tp_group(),