vllm/vllm/v1/cudagraph_dispatcher.py
Yu Gong ea1a26d276 Capture multiple cuda graph across various active loras
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
2025-12-23 19:29:16 +00:00

240 lines
9.9 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
# Module-level logger for this file, created through vLLM's logger factory.
logger = init_logger(__name__)
class CudagraphDispatcher:
    """
    Runtime cudagraph dispatcher to dispatch keys for multiple sets of
    cudagraphs.

    The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
    for FULL cudagraph runtime mode. The keys are initialized depending on
    attention support and what cudagraph mode is set in CompilationConfig. The
    keys stored in the dispatcher are the only source of truth for valid
    cudagraphs that can be dispatched at runtime.

    When a LoRA config is present, keys are additionally specialized on the
    number of active LoRA adapters (see ``_get_lora_cases``), so a separate
    key exists per (padded batch size, active-LoRA count) pair.

    At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
    based on the input key. After dispatching (communicated via forward
    context), the cudagraph wrappers will trust the dispatch key to either
    capture or replay (if the mode matches), or pass through to the underlying
    runnable without cudagraph (if the mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        # Tokens per request in a uniform decode step: 1, plus the number of
        # speculative tokens when speculative decoding is configured.
        self.uniform_decode_query_len = (
            1
            if not self.vllm_config.speculative_config
            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
        )

        # Dict to store valid cudagraph dispatching keys.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            CUDAGraphMode.PIECEWISE: set(),
            CUDAGraphMode.FULL: set(),
        }

        assert (
            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
            or self.compilation_config.is_attention_compiled_piecewise()
        ), (
            "Compilation mode should be CompilationMode.VLLM_COMPILE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            "and attention should be in splitting_ops or "
            "inductor splitting should be used. "
            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
            f"compilation_mode={self.compilation_config.mode}, "
            f"splitting_ops={self.compilation_config.splitting_ops}"
        )

        # Set to True by initialize_cudagraph_keys(); until then dispatch()
        # always returns CUDAGraphMode.NONE.
        self.keys_initialized = False
        # NOTE(review): not read anywhere in this class; kept in case other
        # code inspects it.
        self.specialize_lora_count = False

    def _create_padded_batch_descriptor(
        self,
        num_tokens: int,
        uniform_decode: bool,
        has_lora: bool,
        num_active_loras: int = 0,
    ) -> BatchDescriptor:
        """
        Build the batch descriptor used as a dispatch key.

        ``num_tokens`` is padded with ``vllm_config.pad_for_cudagraph``. A
        batch is kept as uniform-decode only if the current cudagraph mode has
        a FULL component. In that case ``num_reqs`` is derived from the padded
        token count. Otherwise ``num_reqs`` is the padded token count capped
        at ``max_num_seqs``.

        Requires ``self.cudagraph_mode``, which is set by
        ``initialize_cudagraph_keys``.
        """
        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
        uniform_decode_query_len = self.uniform_decode_query_len
        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)

        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
            num_reqs = num_tokens_padded // uniform_decode_query_len
            assert num_tokens_padded % uniform_decode_query_len == 0
        else:
            uniform_decode = False
            num_reqs = min(num_tokens_padded, max_num_seqs)

        return BatchDescriptor(
            num_tokens=num_tokens_padded,
            num_reqs=num_reqs,
            uniform=uniform_decode,
            has_lora=has_lora,
            num_active_loras=num_active_loras,
        )

    def add_cudagraph_key(
        self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
    ):
        """Register ``batch_descriptor`` as a valid key for ``runtime_mode``.

        Only PIECEWISE and FULL are accepted runtime modes.
        """
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        )
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def _get_lora_cases(self) -> list[tuple[bool, int]]:
        """
        Return the ``(has_lora, num_active_loras)`` pairs to capture.

        - Without a LoRA config: ``[(False, 0)]``.
        - With a LoRA config: one case per active count from 1 to
          ``max_loras``. When speculative decoding is enabled, only
          ``max_loras`` is captured, to avoid conflicts with torch.compile
          during CUDA graph capture.
        - If ``cudagraph_specialize_lora`` is True, the no-LoRA case
          ``(False, 0)`` is appended as well.

        Note that with a LoRA config and ``cudagraph_specialize_lora`` False,
        no ``has_lora=False`` key is captured.
        """
        if not self.vllm_config.lora_config:
            return [(False, 0)]

        max_loras = self.vllm_config.lora_config.max_loras

        # When speculative decoding is enabled, only capture with max_loras
        # to avoid torch.compile conflicts during CUDA graph capture
        if self.vllm_config.speculative_config is not None:
            lora_cases = [(True, max_loras)]
            if self.compilation_config.cudagraph_specialize_lora:
                lora_cases.append((False, 0))
            return lora_cases

        # Capture for each num_active_loras from 1 to max_loras
        lora_cases = [(True, n) for n in range(1, max_loras + 1)]

        # Also capture the no-lora case
        if self.compilation_config.cudagraph_specialize_lora:
            lora_cases.append((False, 0))

        return lora_cases

    def initialize_cudagraph_keys(
        self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
    ):
        """
        Populate ``cudagraph_keys`` for ``cudagraph_mode``.

        Args:
            cudagraph_mode: The resolved cudagraph mode. It is stored on
                ``self`` and used by later dispatches.
            uniform_decode_query_len: Tokens per request in a uniform decode
                step. It bounds the batch sizes captured for FULL decode
                graphs.
        """
        # This should be called only after attention backend is initialized. So we can
        # get the correct cudagraph mode after backend support is resolved.
        self.cudagraph_mode = cudagraph_mode

        # Whether a LoRA config is present (keys are then specialized on the
        # active-LoRA count; see _get_lora_cases).
        self.has_lora_config = self.vllm_config.lora_config is not None

        # Get LoRA cases to capture
        lora_cases = self._get_lora_cases()

        # Note: we create all valid keys for cudagraph here but do not
        # guarantee all keys would be used. For example, if we allow lazy
        # capturing in future PR, some keys may never be triggered.
        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            for bs, (has_lora, num_active_loras) in product(
                self.compilation_config.cudagraph_capture_sizes, lora_cases
            ):
                self.add_cudagraph_key(
                    cudagraph_mode.mixed_mode(),
                    self._create_padded_batch_descriptor(
                        bs, False, has_lora, num_active_loras
                    ).relax_for_mixed_batch_cudagraphs(),
                )

        # if decode cudagraph mode is FULL, and we don't already have mixed
        # mode full cudagraphs then add them here.
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
        ):
            max_num_tokens = (
                uniform_decode_query_len
                * self.vllm_config.scheduler_config.max_num_seqs
            )
            cudagraph_capture_sizes_for_decode = [
                x
                for x in self.compilation_config.cudagraph_capture_sizes
                if x <= max_num_tokens and x >= uniform_decode_query_len
            ]
            for bs, (has_lora, num_active_loras) in product(
                cudagraph_capture_sizes_for_decode, lora_cases
            ):
                self.add_cudagraph_key(
                    CUDAGraphMode.FULL,
                    self._create_padded_batch_descriptor(
                        bs, True, has_lora, num_active_loras
                    ),
                )
        self.keys_initialized = True

    def _resolve_num_active_loras(self, has_lora: bool, num_active_loras: int) -> int:
        """
        Map a requested active-LoRA count onto a count keys are captured for.

        - ``has_lora`` False: return 0, the count used by the no-LoRA capture
          case in ``_get_lora_cases``. Any stray count from the caller is
          ignored.
        - No LoRA config: return the caller's value unchanged. No LoRA keys
          exist, so nothing matches.
        - Speculative decoding enabled: return ``max_loras``, since that is
          the only LoRA count captured.
        - Non-positive count with ``has_lora`` True: return ``max_loras``.
          This covers callers that rely on the default
          ``num_active_loras=0``. LoRA keys are captured only for counts
          1..max_loras, so a 0 count would otherwise never match. This mirrors
          the speculative-decoding path, which already serves every count from
          the max_loras graph.
        """
        if not has_lora:
            return 0
        lora_config = self.vllm_config.lora_config
        if lora_config is None:
            return num_active_loras
        if self.vllm_config.speculative_config is not None or num_active_loras <= 0:
            return lora_config.max_loras
        return num_active_loras

    def dispatch(
        self,
        num_tokens: int,
        uniform_decode: bool,
        has_lora: bool,
        disable_full: bool = False,
        num_active_loras: int = 0,
    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
        """
        Dispatch a batch to a cudagraph runtime mode and a valid batch
        descriptor.

        A new batch descriptor is returned because a uniform batch may be
        dispatched to a graph that supports a more general batch (uniform to
        non-uniform).

        Lookup order:
        1. Exact FULL key (unless ``disable_full``).
        2. Relaxed FULL key (unless ``disable_full``).
        3. Relaxed PIECEWISE key.
        4. Otherwise NONE.

        Args:
            num_tokens: Number of tokens in the batch.
            uniform_decode: Whether this is a uniform decode batch.
            has_lora: Whether this batch has active LoRA adapters.
            disable_full: Whether to disable full cudagraph mode.
            num_active_loras: Number of distinct active LoRA adapters. It is
                normalized by ``_resolve_num_active_loras`` before lookup.

        Returns:
            ``(runtime_mode, batch_descriptor)``. NONE is paired with a
            trivial ``BatchDescriptor(num_tokens)``.
        """
        if (
            not self.keys_initialized
            or self.cudagraph_mode == CUDAGraphMode.NONE
            or num_tokens > self.compilation_config.max_cudagraph_capture_size
        ):
            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

        effective_num_active_loras = self._resolve_num_active_loras(
            has_lora, num_active_loras
        )
        batch_desc = self._create_padded_batch_descriptor(
            num_tokens, uniform_decode, has_lora, effective_num_active_loras
        )
        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

        if not disable_full:
            # check if key exists for full cudagraph
            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_desc

            # otherwise, check if the relaxed key exists
            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, relaxed_batch_desc

        # also check if the relaxed key exists for more "general"
        # piecewise cudagraph
        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

        # finally, just return no cudagraphs and a trivial batch descriptor
        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)