From 288d67d054283d1b1f7346d7bcb74496da21fa04 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Wed, 17 Dec 2025 20:02:06 +0000 Subject: [PATCH 1/8] Using active-loras in grid in fused_moe_lora kernel Signed-off-by: Yu Gong --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index f04936221eea6..e1ea092fcd428 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -81,6 +81,7 @@ def _fused_moe_lora_kernel( # Meta-parameters num_slice_a: tl.constexpr, num_slice_c: tl.constexpr, + max_loras: tl.constexpr, top_k: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -104,7 +105,7 @@ def _fused_moe_lora_kernel( if moe_enabled == 0: # Early exit for the no moe lora case. return - max_loras = tl.num_programs(axis=2) + # max_loras = tl.num_programs(axis=2) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) # calculate pid_m,pid_n @@ -228,6 +229,7 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] @@ -251,7 +253,7 @@ def _fused_moe_lora_shrink( * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_a_stacked), - lora_a_stacked[0].shape[0], + num_active_loras, ) _fused_moe_lora_kernel[grid]( qcurr_hidden_states, @@ -280,6 +282,7 @@ def _fused_moe_lora_shrink( expert_ids.stride(0), slice_a_size=qcurr_hidden_states.numel(), slice_c_size=a_intermediate_cache1.numel() // num_slices, + max_loras=lora_a_stacked[0].shape[0], num_slice_a=1, num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, @@ -322,6 +325,7 @@ def _fused_moe_lora_expand( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, offset: int = 0, ) -> None: @@ -351,7 +355,7 @@ def _fused_moe_lora_expand( grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_b_stacked), - lora_b_stacked[0].shape[0], + num_active_loras, ) _fused_moe_lora_kernel[grid]( a_intermediate_cache1, @@ -382,6 +386,7 @@ def _fused_moe_lora_expand( slice_c_size=b_intermediate_cache1.numel() // num_slices, num_slice_a=num_slices, num_slice_c=num_slices, + max_loras=lora_b_stacked[0].shape[0], top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, IS_PRIMARY=False, @@ -492,6 +497,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, + num_active_loras, mul_routed_weight, ) @@ -538,6 +544,7 @@ def _fused_moe_lora( expand_num_warps, expand_num_stages, expand_split_k, + num_active_loras, mul_routed_weight, offset, ) @@ -601,6 +608,7 @@ def _fused_moe_lora_shrink_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: return @@ -634,6 +642,7 @@ def _fused_moe_lora_expand_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: return From ea1a26d2768da61250d48b7f99d378f492bbad4a Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 18 Dec 2025 19:16:02 +0000 Subject: [PATCH 2/8] Capture multiple cuda graph across various active loras Signed-off-by: Yu Gong --- vllm/forward_context.py | 14 +++- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 6 ++ vllm/lora/ops/triton_ops/lora_expand_op.py | 2 + .../ops/triton_ops/lora_kernel_metadata.py | 10 +++ vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 + vllm/lora/punica_wrapper/punica_gpu.py | 9 +- vllm/v1/cudagraph_dispatcher.py | 84 +++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 65 +++++++++++--- vllm/v1/worker/lora_model_runner_mixin.py | 65 +++++++++++--- 9 files changed, 216 insertions(+), 41 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7a569ec32eac9..77793b8a82aa8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple): """ Whether this batch has active LoRA adapters. """ + num_active_loras: int = 0 + """ + Number of distinct active LoRA adapters in this batch. + When cudagraph_specialize_lora_count is enabled, separate CUDA graphs + are captured for each num_active_loras value. This allows kernels + (like fused_moe_lora) whose grid size depends on num_active_loras + to be properly captured. + """ def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": """ @@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple): with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). """ return BatchDescriptor( - self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora + self.num_tokens, + num_reqs=None, + uniform=False, + has_lora=self.has_lora, + num_active_loras=self.num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index e1ea092fcd428..65c2ccb963e33 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch from vllm.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from vllm.logger import init_logger from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op from .utils import supports_pdl +logger = init_logger(__name__) + _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} @@ -413,6 +417,7 @@ def _fused_moe_lora( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -562,6 +567,7 @@ def _fused_moe_lora_fake( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 311c4b1918597..a2a498c4067c3 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -140,6 +140,7 @@ def _lora_expand( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -291,6 +292,7 @@ def _lora_expand_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, offset_start: int = 0, add_inputs: bool = False, ) -> None: diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index c3bef7680dd0d..8dfa7435080b0 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -28,6 +28,10 @@ class LoRAKernelMeta: # to early exit from inside the lora_expand / lora_shrink torch operation. no_lora_flag_cpu: torch.Tensor + # Number of active LoRAs (unique non-(-1) values in token_lora_mapping) + # Stored as a Python int to avoid GPU->CPU sync during forward pass + num_active_loras: int = 0 + @staticmethod def make( max_loras: int, max_num_tokens: int, device: torch.device | str @@ -73,6 +77,7 @@ class LoRAKernelMeta: self.num_tokens_per_lora.fill_(0) self.lora_token_start_loc.fill_(0) self.no_lora_flag_cpu.fill_(False) + self.num_active_loras = 0 def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: """ @@ -117,6 +122,9 @@ class LoRAKernelMeta: self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_( num_tokens_per_lora, non_blocking=True ) + # Store num_active_loras as Python int (excludes -1 which means no LoRA) + # Valid LoRA IDs are >= 0, so count those + self.num_active_loras = int((lora_ids >= 0).sum().item()) # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) @@ -133,6 +141,7 @@ class LoRAKernelMeta: torch.Tensor, torch.Tensor, torch.Tensor, + int, ]: """ This function returns the kernel metadata required for the current @@ -151,4 +160,5 @@ class LoRAKernelMeta: self.lora_token_start_loc, self.active_lora_ids, self.no_lora_flag_cpu, + self.num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 71bd5e3614667..240117b193506 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -136,6 +136,7 @@ def _lora_shrink( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) scaling: float, ) -> None: """ @@ -269,6 +270,7 @@ def _lora_shrink_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, scaling: float, ) -> None: return diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index ef4b4ab7c3497..46674d5c32f87 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): (max_loras), dtype=torch.int32, device=topk_ids.device ) - (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( - num_tokens + (token_lora_mapping, _, _, _, lora_ids, _, _) = ( + self.token_mapping_meta.meta_args(num_tokens) ) ops.moe_lora_align_block_size( @@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. """ - (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0)) + (_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args( + x.size(0) + ) fused_moe_lora( y, x, @@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): max_lora_rank, top_k_num, lora_ids, + num_active_loras, adapter_enabled, shrink_config.get("BLOCK_SIZE_M", 64), shrink_config.get("BLOCK_SIZE_N", 64), diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 8a3500c0aac6b..fc59f433f3632 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -57,9 +57,14 @@ class CudagraphDispatcher: ) self.keys_initialized = False + self.specialize_lora_count = False def _create_padded_batch_descriptor( - self, num_tokens: int, uniform_decode: bool, has_lora: bool + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -77,6 +82,7 @@ class CudagraphDispatcher: num_reqs=num_reqs, uniform=uniform_decode, has_lora=has_lora, + num_active_loras=num_active_loras, ) def add_cudagraph_key( @@ -87,6 +93,37 @@ class CudagraphDispatcher: ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) + def _get_lora_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for graph capture. + + Returns cases for each num_active_loras from 1 to max_loras. + If cudagraph_specialize_lora is True, also includes the no-lora case. + + Note: When speculative decoding is enabled, we fall back to capturing + only with max_loras to avoid conflicts with torch.compile during + CUDA graph capture. + """ + if not self.vllm_config.lora_config: + return [(False, 0)] + + max_loras = self.vllm_config.lora_config.max_loras + + # When speculative decoding is enabled, only capture with max_loras + # to avoid torch.compile conflicts during CUDA graph capture + if self.vllm_config.speculative_config is not None: + lora_cases = [(True, max_loras)] + if self.compilation_config.cudagraph_specialize_lora: + lora_cases.append((False, 0)) + return lora_cases + + # Capture for each num_active_loras from 1 to max_loras + lora_cases = [(True, n) for n in range(1, max_loras + 1)] + # Also capture the no-lora case + if self.compilation_config.cudagraph_specialize_lora: + lora_cases.append((False, 0)) + return lora_cases + def initialize_cudagraph_keys( self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int ): @@ -94,26 +131,23 @@ class CudagraphDispatcher: # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode - # LoRA activation cases to specialize the cuda graphs on - if self.vllm_config.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] + # Track whether we have LoRA config (always specialize on count) + self.has_lora_config = self.vllm_config.lora_config is not None + + # Get LoRA cases to capture + lora_cases = self._get_lora_cases() # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs, has_lora in product( + for bs, (has_lora, num_active_loras) in product( self.compilation_config.cudagraph_capture_sizes, lora_cases ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), self._create_padded_batch_descriptor( - bs, False, has_lora + bs, False, has_lora, num_active_loras ).relax_for_mixed_batch_cudagraphs(), ) @@ -132,10 +166,14 @@ class CudagraphDispatcher: for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] - for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): + for bs, (has_lora, num_active_loras) in product( + cudagraph_capture_sizes_for_decode, lora_cases + ): self.add_cudagraph_key( CUDAGraphMode.FULL, - self._create_padded_batch_descriptor(bs, True, has_lora), + self._create_padded_batch_descriptor( + bs, True, has_lora, num_active_loras + ), ) self.keys_initialized = True @@ -146,12 +184,20 @@ class CudagraphDispatcher: uniform_decode: bool, has_lora: bool, disable_full: bool = False, + num_active_loras: int = 0, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). + + Args: + num_tokens: Number of tokens in the batch. + uniform_decode: Whether this is a uniform decode batch. + has_lora: Whether this batch has active LoRA adapters. + disable_full: Whether to disable full cudagraph mode. + num_active_loras: Number of distinct active LoRA adapters. """ if ( not self.keys_initialized @@ -160,8 +206,18 @@ class CudagraphDispatcher: ): return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) + # When speculative decoding is enabled, always use max_loras for lookup + # since we only capture graphs with max_loras + effective_num_active_loras = num_active_loras + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.lora_config is not None + and has_lora + ): + effective_num_active_loras = self.vllm_config.lora_config.max_loras + batch_desc = self._create_padded_batch_descriptor( - num_tokens, uniform_decode, has_lora + num_tokens, uniform_decode, has_lora, effective_num_active_loras ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00c585aaaacbb..673eb5e6df2d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2820,6 +2820,7 @@ class GPUModelRunner( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + force_num_active_loras: int | None = None, num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, @@ -2841,11 +2842,13 @@ class GPUModelRunner( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - has_lora = ( - len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None - else force_has_lora + # Compute LoRA state for cudagraph dispatch + num_active_loras = ( + force_num_active_loras + if force_num_active_loras is not None + else len(self.input_batch.lora_id_to_lora_request) ) + has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) dispatch_cudagraph = ( @@ -2854,6 +2857,7 @@ class GPUModelRunner( has_lora=has_lora, uniform_decode=uniform_decode, disable_full=disable_full, + num_active_loras=num_active_loras, ) if not force_eager else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) @@ -4058,6 +4062,7 @@ class GPUModelRunner( remove_lora: bool = True, activate_lora: bool = False, is_graph_capturing: bool = False, + num_active_loras: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -4081,6 +4086,8 @@ class GPUModelRunner( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. + num_active_loras: Number of distinct active LoRAs to capture for. + Only used when cudagraph_specialize_lora_count is True. """ if supports_mm_encoder_only(self.model): # The current dummy run only covers LM execution, so we can skip it. @@ -4162,6 +4169,9 @@ class GPUModelRunner( # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture force_has_lora=activate_lora, + # `force_num_active_loras` is used for cudagraph capture; because we + # need to capture graphs for specific num_active_loras counts + force_num_active_loras=num_active_loras if activate_lora else 0, ) ) @@ -4225,6 +4235,7 @@ class GPUModelRunner( num_sampled_tokens, activate_lora, remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens @@ -4587,6 +4598,24 @@ class GPUModelRunner( self.encoder_cache.clear() gc.collect() + def _get_lora_capture_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture. + + Returns cases for each num_active_loras from 1 to max_loras. + If cudagraph_specialize_lora is True, also includes the no-lora case. + """ + if not self.lora_config: + return [(False, 0)] + + max_loras = self.lora_config.max_loras + # Capture for each num_active_loras from 1 to max_loras + lora_cases = [(True, n) for n in range(1, max_loras + 1)] + # Also capture the no-lora case if cudagraph_specialize_lora is True + if self.compilation_config.cudagraph_specialize_lora: + lora_cases.append((False, 0)) + return lora_cases + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( @@ -4624,13 +4653,8 @@ class GPUModelRunner( cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] + # Build LoRA cases: list of (has_lora, num_active_loras) tuples + lora_cases = self._get_lora_capture_cases() if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() @@ -4695,10 +4719,23 @@ class GPUModelRunner( def _capture_cudagraphs( self, - compilation_cases: list[tuple[int, bool]], + compilation_cases: list[tuple[int, tuple[bool, int]]], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, ): + """ + Capture CUDA graphs for the given compilation cases. + + Args: + compilation_cases: List of (num_tokens, (has_lora, num_active_loras)) + tuples. + - num_tokens: batch size to capture + - has_lora: whether LoRA is active for this capture + - num_active_loras: number of distinct active LoRAs + (0 if not specializing) + cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode. + uniform_decode: Whether this is a uniform decode batch. + """ assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode.valid_runtime_modes() @@ -4716,7 +4753,7 @@ class GPUModelRunner( ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for num_tokens, (activate_lora, num_active_loras) in compilation_cases: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -4748,6 +4785,7 @@ class GPUModelRunner( skip_eplb=True, remove_lora=False, activate_lora=activate_lora, + num_active_loras=num_active_loras, ) self._dummy_run( num_tokens, @@ -4757,6 +4795,7 @@ class GPUModelRunner( skip_eplb=True, remove_lora=False, activate_lora=activate_lora, + num_active_loras=num_active_loras, is_graph_capturing=True, ) self.maybe_remove_all_loras(self.lora_config) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a67246146005c..06c933d214b79 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -4,7 +4,10 @@ Define LoRA functionality mixin for model runners. """ +from __future__ import annotations + from contextlib import contextmanager +from typing import TYPE_CHECKING, TypeAlias import numpy as np import torch @@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch -InputBatch = TPUInputBatch | GPUInputBatch +if TYPE_CHECKING: + InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch logger = init_logger(__name__) @@ -129,7 +133,20 @@ class LoRAModelRunnerMixin: num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray | None = None, activate_lora: bool = True, + num_active_loras: int = 0, ): + """ + Context manager to select dummy LoRAs for capture/warmup. + + Args: + lora_config: LoRA configuration, or None if LoRA is disabled. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. + activate_lora: Whether to activate LoRAs (False means no LoRA). + num_active_loras: Number of distinct active LoRAs to use. + - 0: Use all max_loras (default behavior for has_lora=True). + - >0: Use exactly this many distinct LoRAs. + """ if num_sampled_tokens is None: num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) @@ -140,13 +157,24 @@ class LoRAModelRunnerMixin: assert self.lora_manager is not None, "LoRA is not enabled" num_reqs = len(num_scheduled_tokens) - num_loras = lora_config.max_loras + max_loras = lora_config.max_loras + + # Determine how many distinct LoRAs to use + if not activate_lora: + # No LoRA active + effective_num_loras = 0 + elif num_active_loras > 0: + # Specific number of active LoRAs requested + effective_num_loras = min(num_active_loras, max_loras) + else: + # Default: use all max_loras + effective_num_loras = max_loras # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. - if activate_lora: + if effective_num_loras > 0: prompt_lora_mapping = ( - np.arange(num_reqs, dtype=np.int32) % num_loras + np.arange(num_reqs, dtype=np.int32) % effective_num_loras ) + 1 else: prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) @@ -157,19 +185,20 @@ class LoRAModelRunnerMixin: # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) - # Make dummy lora requests + # Make dummy lora requests (only for the active LoRAs) lora_requests: set[LoRARequest] = { LoRARequest( lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_path="/not/a/real/path", ) - for lora_id in range(1, num_loras + 1) + for lora_id in range(1, effective_num_loras + 1) } - self._set_active_loras( - tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests - ) + if lora_requests: + self._set_active_loras( + tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests + ) yield @@ -181,11 +210,27 @@ class LoRAModelRunnerMixin: num_sampled_tokens: np.ndarray, activate_lora: bool = True, remove_lora: bool = True, + num_active_loras: int = 0, ): + """ + Context manager for dummy runs with LoRA. + + Args: + lora_config: LoRA configuration. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. + activate_lora: Whether to activate LoRAs. + remove_lora: Whether to remove LoRAs after the context exits. + num_active_loras: Number of distinct active LoRAs to use. + """ with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora + lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + num_active_loras, ), ): yield From aa7917aaaaea805146d67720e043fde1c7b7fbe2 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 18 Dec 2025 21:09:47 +0000 Subject: [PATCH 3/8] Clean code Signed-off-by: Yu Gong --- vllm/v1/cudagraph_dispatcher.py | 8 -------- vllm/v1/worker/gpu_model_runner.py | 20 +------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index fc59f433f3632..209dc85b5bbc9 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -109,14 +109,6 @@ class CudagraphDispatcher: max_loras = self.vllm_config.lora_config.max_loras - # When speculative decoding is enabled, only capture with max_loras - # to avoid torch.compile conflicts during CUDA graph capture - if self.vllm_config.speculative_config is not None: - lora_cases = [(True, max_loras)] - if self.compilation_config.cudagraph_specialize_lora: - lora_cases.append((False, 0)) - return lora_cases - # Capture for each num_active_loras from 1 to max_loras lora_cases = [(True, n) for n in range(1, max_loras + 1)] # Also capture the no-lora case diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 673eb5e6df2d2..fabb5f07ba683 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4598,24 +4598,6 @@ class GPUModelRunner( self.encoder_cache.clear() gc.collect() - def _get_lora_capture_cases(self) -> list[tuple[bool, int]]: - """ - Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture. - - Returns cases for each num_active_loras from 1 to max_loras. - If cudagraph_specialize_lora is True, also includes the no-lora case. - """ - if not self.lora_config: - return [(False, 0)] - - max_loras = self.lora_config.max_loras - # Capture for each num_active_loras from 1 to max_loras - lora_cases = [(True, n) for n in range(1, max_loras + 1)] - # Also capture the no-lora case if cudagraph_specialize_lora is True - if self.compilation_config.cudagraph_specialize_lora: - lora_cases.append((False, 0)) - return lora_cases - def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( @@ -4654,7 +4636,7 @@ class GPUModelRunner( assert cudagraph_mode is not None # Build LoRA cases: list of (has_lora, num_active_loras) tuples - lora_cases = self._get_lora_capture_cases() + lora_cases = self.cudagraph_dispatcher._get_lora_cases() if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() From 3b5270a6afbefe38c3367ff1957cc03901c06a69 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 18 Dec 2025 21:22:25 +0000 Subject: [PATCH 4/8] fix bug of always capture lora even no-lora case Signed-off-by: Yu Gong --- vllm/v1/worker/lora_model_runner_mixin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 06c933d214b79..01d73ccecf84a 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -195,10 +195,12 @@ class LoRAModelRunnerMixin: for lora_id in range(1, effective_num_loras + 1) } - if lora_requests: - self._set_active_loras( - tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests - ) + # Always call _set_active_loras to ensure the mapping is updated. + # This is important when capturing no-LoRA graphs (effective_num_loras=0) + # after capturing LoRA graphs, as we need to clear the previous mapping. + self._set_active_loras( + tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests + ) yield From b28be8ea577a2cc9b36d5852f8f4a4203669c76b Mon Sep 17 00:00:00 2001 From: Kevin McKay Date: Sun, 21 Dec 2025 23:14:27 -0600 Subject: [PATCH 5/8] [Misc] Fix typo: 'occured' -> 'occurred' (#31120) Signed-off-by: c0de128 From e475bdae39768d36a72c430fef9836ff3ee751d6 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 18 Dec 2025 19:16:02 +0000 Subject: [PATCH 6/8] Capture multiple cuda graph across various active loras Signed-off-by: Yu Gong --- vllm/v1/cudagraph_dispatcher.py | 30 +++++++----------------------- vllm/v1/worker/gpu_model_runner.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 209dc85b5bbc9..7fe11abee30d8 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -58,6 +58,7 @@ class CudagraphDispatcher: self.keys_initialized = False self.specialize_lora_count = False + self.specialize_lora_count = False def _create_padded_batch_descriptor( self, @@ -65,6 +66,11 @@ class CudagraphDispatcher: uniform_decode: bool, has_lora: bool, num_active_loras: int = 0, + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -83,6 +89,7 @@ class CudagraphDispatcher: uniform=uniform_decode, has_lora=has_lora, num_active_loras=num_active_loras, + num_active_loras=num_active_loras, ) def add_cudagraph_key( @@ -93,29 +100,6 @@ class CudagraphDispatcher: ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) - def _get_lora_cases(self) -> list[tuple[bool, int]]: - """ - Returns list of (has_lora, num_active_loras) tuples for graph capture. - - Returns cases for each num_active_loras from 1 to max_loras. - If cudagraph_specialize_lora is True, also includes the no-lora case. - - Note: When speculative decoding is enabled, we fall back to capturing - only with max_loras to avoid conflicts with torch.compile during - CUDA graph capture. - """ - if not self.vllm_config.lora_config: - return [(False, 0)] - - max_loras = self.vllm_config.lora_config.max_loras - - # Capture for each num_active_loras from 1 to max_loras - lora_cases = [(True, n) for n in range(1, max_loras + 1)] - # Also capture the no-lora case - if self.compilation_config.cudagraph_specialize_lora: - lora_cases.append((False, 0)) - return lora_cases - def initialize_cudagraph_keys( self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int ): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fabb5f07ba683..d8794e8e4fcaa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4598,6 +4598,24 @@ class GPUModelRunner( self.encoder_cache.clear() gc.collect() + def _get_lora_capture_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture. + + Returns cases for each num_active_loras from 1 to max_loras. + If cudagraph_specialize_lora is True, also includes the no-lora case. + """ + if not self.lora_config: + return [(False, 0)] + + max_loras = self.lora_config.max_loras + # Capture for each num_active_loras from 1 to max_loras + lora_cases = [(True, n) for n in range(1, max_loras + 1)] + # Also capture the no-lora case if cudagraph_specialize_lora is True + if self.compilation_config.cudagraph_specialize_lora: + lora_cases.append((False, 0)) + return lora_cases + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( From e95e2145c5ac1f1c9478f9940cc951cf07e6a003 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 23 Dec 2025 19:25:07 +0000 Subject: [PATCH 7/8] Construting grid with num of active lora Signed-off-by: Yu Gong --- vllm/lora/ops/triton_ops/lora_expand_op.py | 2 +- vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 +- vllm/v1/cudagraph_dispatcher.py | 26 ++++++++++++++++------ vllm/v1/worker/gpu_model_runner.py | 18 --------------- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index a2a498c4067c3..de2d70a7fbbdc 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -240,7 +240,7 @@ def _lora_expand( # Each LoRA receives its own set of thread blocks for output # computation. If some LoRA doesn't have any tokens to process, its # thread blocks simply exit. - MAX_LORAS, + num_active_loras, ) use_gdc = supports_pdl(inputs.device) _lora_expand_kernel[grid]( diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 240117b193506..912c546ace8b0 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -220,7 +220,7 @@ def _lora_shrink( # Each LoRA receives its own set of thread blocks for output # computation. If some LoRA doesn't have any tokens to process, its # thread blocks exit early. - MAX_LORAS, + num_active_loras, ) use_gdc = supports_pdl(inputs.device) _lora_shrink_kernel[grid]( diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 7fe11abee30d8..9d731bde644d3 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -58,7 +58,25 @@ class CudagraphDispatcher: self.keys_initialized = False self.specialize_lora_count = False - self.specialize_lora_count = False + + def _get_lora_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for CUDA graph + capture. This is the single source of truth for LoRA capture cases. + """ + lora_config = self.vllm_config.lora_config + if lora_config is None: + # No LoRA configured - single case with no LoRA + return [(False, 0)] + + # LoRA is enabled - capture graphs for different active LoRA counts + # Always include the no-LoRA case (for requests without adapters) + cases: list[tuple[bool, int]] = [(False, 0)] + + for n in range(1, lora_config.max_loras + 1): + cases.append((True, n)) + + return cases def _create_padded_batch_descriptor( self, @@ -66,11 +84,6 @@ class CudagraphDispatcher: uniform_decode: bool, has_lora: bool, num_active_loras: int = 0, - self, - num_tokens: int, - uniform_decode: bool, - has_lora: bool, - num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -89,7 +102,6 @@ class CudagraphDispatcher: uniform=uniform_decode, has_lora=has_lora, num_active_loras=num_active_loras, - num_active_loras=num_active_loras, ) def add_cudagraph_key( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d8794e8e4fcaa..fabb5f07ba683 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4598,24 +4598,6 @@ class GPUModelRunner( self.encoder_cache.clear() gc.collect() - def _get_lora_capture_cases(self) -> list[tuple[bool, int]]: - """ - Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture. - - Returns cases for each num_active_loras from 1 to max_loras. - If cudagraph_specialize_lora is True, also includes the no-lora case. - """ - if not self.lora_config: - return [(False, 0)] - - max_loras = self.lora_config.max_loras - # Capture for each num_active_loras from 1 to max_loras - lora_cases = [(True, n) for n in range(1, max_loras + 1)] - # Also capture the no-lora case if cudagraph_specialize_lora is True - if self.compilation_config.cudagraph_specialize_lora: - lora_cases.append((False, 0)) - return lora_cases - def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( From 094eaef7b31f5c0bfd6c200a972ad332be374fcf Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 23 Dec 2025 20:00:19 +0000 Subject: [PATCH 8/8] remove the constraint for SD Signed-off-by: Yu Gong --- vllm/v1/cudagraph_dispatcher.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 9d731bde644d3..85cd3105571d8 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -194,15 +194,7 @@ class CudagraphDispatcher: ): return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) - # When speculative decoding is enabled, always use max_loras for lookup - # since we only capture graphs with max_loras effective_num_active_loras = num_active_loras - if ( - self.vllm_config.speculative_config is not None - and self.vllm_config.lora_config is not None - and has_lora - ): - effective_num_active_loras = self.vllm_config.lora_config.max_loras batch_desc = self._create_padded_batch_descriptor( num_tokens, uniform_decode, has_lora, effective_num_active_loras