Capture multiple CUDA graphs across varying numbers of active LoRAs

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
Yu Gong 2025-12-18 19:16:02 +00:00
parent 288d67d054
commit ea1a26d276
9 changed files with 216 additions and 41 deletions

View File

@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
    """
    Return a relaxed copy of this descriptor for use with PIECEWISE
    cudagraphs (or mixed prefill-decode FA cudagraphs): num_reqs is
    dropped and uniform is cleared, while num_tokens, has_lora and
    num_active_loras are preserved so LoRA-count-specialized graphs
    still key correctly.
    """
    return BatchDescriptor(
        self.num_tokens,
        num_reqs=None,
        uniform=False,
        has_lora=self.has_lora,
        num_active_loras=self.num_active_loras,
    )

View File

@ -1,17 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
logger = init_logger(__name__)
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@ -413,6 +417,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@ -562,6 +567,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,

View File

@ -140,6 +140,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
@ -291,6 +292,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:

View File

@ -28,6 +28,10 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
@staticmethod
def make(
max_loras: int, max_num_tokens: int, device: torch.device | str
@ -73,6 +77,7 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
@ -117,6 +122,9 @@ class LoRAKernelMeta:
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True
)
# Store num_active_loras as Python int (excludes -1 which means no LoRA)
# Valid LoRA IDs are >= 0, so count those
self.num_active_loras = int((lora_ids >= 0).sum().item())
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
@ -133,6 +141,7 @@ class LoRAKernelMeta:
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
]:
"""
This function returns the kernel metadata required for the current
@ -151,4 +160,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras,
)

View File

@ -136,6 +136,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
scaling: float,
) -> None:
"""
@ -269,6 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float,
) -> None:
return

View File

@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
(max_loras), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
self.token_mapping_meta.meta_args(num_tokens)
)
ops.moe_lora_align_block_size(
@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
(_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args(
x.size(0)
)
fused_moe_lora(
y,
x,
@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),

View File

@ -57,9 +57,14 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
self.specialize_lora_count = False
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
@ -77,6 +82,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
num_active_loras=num_active_loras,
)
def add_cudagraph_key(
@ -87,6 +93,37 @@ class CudagraphDispatcher:
)
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def _get_lora_cases(self) -> list[tuple[bool, int]]:
"""
Returns list of (has_lora, num_active_loras) tuples for graph capture.
Returns cases for each num_active_loras from 1 to max_loras.
If cudagraph_specialize_lora is True, also includes the no-lora case.
Note: When speculative decoding is enabled, we fall back to capturing
only with max_loras to avoid conflicts with torch.compile during
CUDA graph capture.
"""
if not self.vllm_config.lora_config:
return [(False, 0)]
max_loras = self.vllm_config.lora_config.max_loras
# When speculative decoding is enabled, only capture with max_loras
# to avoid torch.compile conflicts during CUDA graph capture
if self.vllm_config.speculative_config is not None:
lora_cases = [(True, max_loras)]
if self.compilation_config.cudagraph_specialize_lora:
lora_cases.append((False, 0))
return lora_cases
# Capture for each num_active_loras from 1 to max_loras
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
# Also capture the no-lora case
if self.compilation_config.cudagraph_specialize_lora:
lora_cases.append((False, 0))
return lora_cases
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
):
@ -94,26 +131,23 @@ class CudagraphDispatcher:
# get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Track whether we have LoRA config (always specialize on count)
self.has_lora_config = self.vllm_config.lora_config is not None
# Get LoRA cases to capture
lora_cases = self._get_lora_cases()
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product(
for bs, (has_lora, num_active_loras) in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, has_lora
bs, False, has_lora, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
)
@ -132,10 +166,14 @@ class CudagraphDispatcher:
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
for bs, (has_lora, num_active_loras) in product(
cudagraph_capture_sizes_for_decode, lora_cases
):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora),
self._create_padded_batch_descriptor(
bs, True, has_lora, num_active_loras
),
)
self.keys_initialized = True
@ -146,12 +184,20 @@ class CudagraphDispatcher:
uniform_decode: bool,
has_lora: bool,
disable_full: bool = False,
num_active_loras: int = 0,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether this is a uniform decode batch.
has_lora: Whether this batch has active LoRA adapters.
disable_full: Whether to disable full cudagraph mode.
num_active_loras: Number of distinct active LoRA adapters.
"""
if (
not self.keys_initialized
@ -160,8 +206,18 @@ class CudagraphDispatcher:
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
# When speculative decoding is enabled, always use max_loras for lookup
# since we only capture graphs with max_loras
effective_num_active_loras = num_active_loras
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.lora_config is not None
and has_lora
):
effective_num_active_loras = self.vllm_config.lora_config.max_loras
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

View File

@ -2820,6 +2820,7 @@ class GPUModelRunner(
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
force_num_active_loras: int | None = None,
num_encoder_reqs: int = 0,
) -> tuple[
CUDAGraphMode,
@ -2841,11 +2842,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
)
has_lora = (
len(self.input_batch.lora_id_to_lora_request) > 0
if force_has_lora is None
else force_has_lora
# Compute LoRA state for cudagraph dispatch
num_active_loras = (
force_num_active_loras
if force_num_active_loras is not None
else len(self.input_batch.lora_id_to_lora_request)
)
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = (
@ -2854,6 +2857,7 @@ class GPUModelRunner(
has_lora=has_lora,
uniform_decode=uniform_decode,
disable_full=disable_full,
num_active_loras=num_active_loras,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
@ -4058,6 +4062,7 @@ class GPUModelRunner(
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
num_active_loras: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@ -4081,6 +4086,8 @@ class GPUModelRunner(
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, dummy_run is performed without LoRAs.
num_active_loras: Number of distinct active LoRAs to capture for.
Only used when cudagraph_specialize_lora_count is True.
"""
if supports_mm_encoder_only(self.model):
# The current dummy run only covers LM execution, so we can skip it.
@ -4162,6 +4169,9 @@ class GPUModelRunner(
# activated later in the context manager, but we need to know the
# LoRA state when determining the batch descriptor for capture
force_has_lora=activate_lora,
# `force_num_active_loras` is used for cudagraph capture; because we
# need to capture graphs for specific num_active_loras counts
force_num_active_loras=num_active_loras if activate_lora else 0,
)
)
@ -4225,6 +4235,7 @@ class GPUModelRunner(
num_sampled_tokens,
activate_lora,
remove_lora,
num_active_loras,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens
@ -4587,6 +4598,24 @@ class GPUModelRunner(
self.encoder_cache.clear()
gc.collect()
def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
"""
Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
Returns cases for each num_active_loras from 1 to max_loras.
If cudagraph_specialize_lora is True, also includes the no-lora case.
"""
if not self.lora_config:
return [(False, 0)]
max_loras = self.lora_config.max_loras
# Capture for each num_active_loras from 1 to max_loras
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
# Also capture the no-lora case if cudagraph_specialize_lora is True
if self.compilation_config.cudagraph_specialize_lora:
lora_cases.append((False, 0))
return lora_cases
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
@ -4624,13 +4653,8 @@ class GPUModelRunner(
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Build LoRA cases: list of (has_lora, num_active_loras) tuples
lora_cases = self._get_lora_capture_cases()
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
@ -4695,10 +4719,23 @@ class GPUModelRunner(
def _capture_cudagraphs(
self,
compilation_cases: list[tuple[int, bool]],
compilation_cases: list[tuple[int, tuple[bool, int]]],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
"""
Capture CUDA graphs for the given compilation cases.
Args:
compilation_cases: List of (num_tokens, (has_lora, num_active_loras))
tuples.
- num_tokens: batch size to capture
- has_lora: whether LoRA is active for this capture
- num_active_loras: number of distinct active LoRAs
(0 if not specializing)
cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode.
uniform_decode: Whether this is a uniform decode batch.
"""
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
@ -4716,7 +4753,7 @@ class GPUModelRunner(
)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases:
for num_tokens, (activate_lora, num_active_loras) in compilation_cases:
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@ -4748,6 +4785,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
)
self._dummy_run(
num_tokens,
@ -4757,6 +4795,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)

View File

@ -4,7 +4,10 @@
Define LoRA functionality mixin for model runners.
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, TypeAlias
import numpy as np
import torch
@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = TPUInputBatch | GPUInputBatch
if TYPE_CHECKING:
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
logger = init_logger(__name__)
@ -129,7 +133,20 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True,
num_active_loras: int = 0,
):
"""
Context manager to select dummy LoRAs for capture/warmup.
Args:
lora_config: LoRA configuration, or None if LoRA is disabled.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
activate_lora: Whether to activate LoRAs (False means no LoRA).
num_active_loras: Number of distinct active LoRAs to use.
- 0: Use all max_loras (default behavior for has_lora=True).
- >0: Use exactly this many distinct LoRAs.
"""
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
@ -140,13 +157,24 @@ class LoRAModelRunnerMixin:
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
max_loras = lora_config.max_loras
# Determine how many distinct LoRAs to use
if not activate_lora:
# No LoRA active
effective_num_loras = 0
elif num_active_loras > 0:
# Specific number of active LoRAs requested
effective_num_loras = min(num_active_loras, max_loras)
else:
# Default: use all max_loras
effective_num_loras = max_loras
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
if activate_lora:
if effective_num_loras > 0:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
) + 1
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
@ -157,19 +185,20 @@ class LoRAModelRunnerMixin:
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests
# Make dummy lora requests (only for the active LoRAs)
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
for lora_id in range(1, effective_num_loras + 1)
}
self._set_active_loras(
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
)
if lora_requests:
self._set_active_loras(
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
)
yield
@ -181,11 +210,27 @@ class LoRAModelRunnerMixin:
num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
num_active_loras: int = 0,
):
"""
Context manager for dummy runs with LoRA.
Args:
lora_config: LoRA configuration.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
activate_lora: Whether to activate LoRAs.
remove_lora: Whether to remove LoRAs after the context exits.
num_active_loras: Number of distinct active LoRAs to use.
"""
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
lora_config,
num_scheduled_tokens,
num_sampled_tokens,
activate_lora,
num_active_loras,
),
):
yield