diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7a569ec32eac9..77793b8a82aa8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple): """ Whether this batch has active LoRA adapters. """ + num_active_loras: int = 0 + """ + Number of distinct active LoRA adapters in this batch. + When cudagraph_specialize_lora_count is enabled, separate CUDA graphs + are captured for each num_active_loras value. This allows kernels + (like fused_moe_lora) whose grid size depends on num_active_loras + to be properly captured. + """ def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": """ @@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple): with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). """ return BatchDescriptor( - self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora + self.num_tokens, + num_reqs=None, + uniform=False, + has_lora=self.has_lora, + num_active_loras=self.num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index f04936221eea6..65c2ccb963e33 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch from vllm.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from vllm.logger import init_logger from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op from .utils import supports_pdl +logger = init_logger(__name__) + _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} @@ -81,6 +85,7 @@ def _fused_moe_lora_kernel( # Meta-parameters num_slice_a: tl.constexpr, num_slice_c: tl.constexpr, + max_loras: tl.constexpr, top_k: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -104,7 +109,7 @@ def _fused_moe_lora_kernel( if moe_enabled == 0: # Early exit for the no moe lora case. return - max_loras = tl.num_programs(axis=2) + # max_loras = tl.num_programs(axis=2) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) # calculate pid_m,pid_n @@ -228,6 +233,7 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] @@ -251,7 +257,7 @@ def _fused_moe_lora_shrink( * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_a_stacked), - lora_a_stacked[0].shape[0], + num_active_loras, ) _fused_moe_lora_kernel[grid]( qcurr_hidden_states, @@ -280,6 +286,7 @@ def _fused_moe_lora_shrink( expert_ids.stride(0), slice_a_size=qcurr_hidden_states.numel(), slice_c_size=a_intermediate_cache1.numel() // num_slices, + max_loras=lora_a_stacked[0].shape[0], num_slice_a=1, num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, @@ -322,6 +329,7 @@ def _fused_moe_lora_expand( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, offset: int = 0, ) -> None: @@ -351,7 +359,7 @@ def _fused_moe_lora_expand( grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_b_stacked), - lora_b_stacked[0].shape[0], + num_active_loras, ) _fused_moe_lora_kernel[grid]( a_intermediate_cache1, @@ -382,6 +390,7 @@ def _fused_moe_lora_expand( slice_c_size=b_intermediate_cache1.numel() // num_slices, num_slice_a=num_slices, num_slice_c=num_slices, + max_loras=lora_b_stacked[0].shape[0], top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, IS_PRIMARY=False, @@ -408,6 +417,7 @@ def _fused_moe_lora( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -492,6 +502,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, + num_active_loras, mul_routed_weight, ) @@ -538,6 +549,7 @@ def _fused_moe_lora( expand_num_warps, expand_num_stages, expand_split_k, + num_active_loras, mul_routed_weight, offset, ) @@ -555,6 +567,7 @@ def _fused_moe_lora_fake( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -601,6 +614,7 @@ def _fused_moe_lora_shrink_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: return @@ -634,6 +648,7 @@ def _fused_moe_lora_expand_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, ) -> None: return diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 311c4b1918597..de2d70a7fbbdc 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -140,6 +140,7 @@ def _lora_expand( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -239,7 +240,7 @@ def _lora_expand( # Each LoRA receives its own set of thread blocks for output # computation. If some LoRA doesn't have any tokens to process, its # thread blocks simply exit. - MAX_LORAS, + num_active_loras, ) use_gdc = supports_pdl(inputs.device) _lora_expand_kernel[grid]( @@ -291,6 +292,7 @@ def _lora_expand_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, offset_start: int = 0, add_inputs: bool = False, ) -> None: diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index c3bef7680dd0d..8dfa7435080b0 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -28,6 +28,10 @@ class LoRAKernelMeta: # to early exit from inside the lora_expand / lora_shrink torch operation. no_lora_flag_cpu: torch.Tensor + # Number of active LoRAs (unique non-(-1) values in token_lora_mapping) + # Stored as a Python int to avoid GPU->CPU sync during forward pass + num_active_loras: int = 0 + @staticmethod def make( max_loras: int, max_num_tokens: int, device: torch.device | str @@ -73,6 +77,7 @@ class LoRAKernelMeta: self.num_tokens_per_lora.fill_(0) self.lora_token_start_loc.fill_(0) self.no_lora_flag_cpu.fill_(False) + self.num_active_loras = 0 def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: """ @@ -117,6 +122,9 @@ class LoRAKernelMeta: self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_( num_tokens_per_lora, non_blocking=True ) + # Store num_active_loras as Python int (excludes -1 which means no LoRA) + # Valid LoRA IDs are >= 0, so count those + self.num_active_loras = int((lora_ids >= 0).sum().item()) # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) @@ -133,6 +141,7 @@ class LoRAKernelMeta: torch.Tensor, torch.Tensor, torch.Tensor, + int, ]: """ This function returns the kernel metadata required for the current @@ -151,4 +160,5 @@ class LoRAKernelMeta: self.lora_token_start_loc, self.active_lora_ids, self.no_lora_flag_cpu, + self.num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 71bd5e3614667..912c546ace8b0 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -136,6 +136,7 @@ def _lora_shrink( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) scaling: float, ) -> None: """ @@ -219,7 +220,7 @@ def _lora_shrink( # Each LoRA receives its own set of thread blocks for output # computation. If some LoRA doesn't have any tokens to process, its # thread blocks exit early. - MAX_LORAS, + num_active_loras, ) use_gdc = supports_pdl(inputs.device) _lora_shrink_kernel[grid]( @@ -269,6 +270,7 @@ def _lora_shrink_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, scaling: float, ) -> None: return diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index ef4b4ab7c3497..46674d5c32f87 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): (max_loras), dtype=torch.int32, device=topk_ids.device ) - (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( - num_tokens + (token_lora_mapping, _, _, _, lora_ids, _, _) = ( + self.token_mapping_meta.meta_args(num_tokens) ) ops.moe_lora_align_block_size( @@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. """ - (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0)) + (_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args( + x.size(0) + ) fused_moe_lora( y, x, @@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): max_lora_rank, top_k_num, lora_ids, + num_active_loras, adapter_enabled, shrink_config.get("BLOCK_SIZE_M", 64), shrink_config.get("BLOCK_SIZE_N", 64), diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 8a3500c0aac6b..85cd3105571d8 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -57,9 +57,33 @@ class CudagraphDispatcher: ) self.keys_initialized = False + self.specialize_lora_count = False + + def _get_lora_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for CUDA graph + capture. This is the single source of truth for LoRA capture cases. + """ + lora_config = self.vllm_config.lora_config + if lora_config is None: + # No LoRA configured - single case with no LoRA + return [(False, 0)] + + # LoRA is enabled - capture graphs for different active LoRA counts + # Always include the no-LoRA case (for requests without adapters) + cases: list[tuple[bool, int]] = [(False, 0)] + + for n in range(1, lora_config.max_loras + 1): + cases.append((True, n)) + + return cases def _create_padded_batch_descriptor( - self, num_tokens: int, uniform_decode: bool, has_lora: bool + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -77,6 +101,7 @@ class CudagraphDispatcher: num_reqs=num_reqs, uniform=uniform_decode, has_lora=has_lora, + num_active_loras=num_active_loras, ) def add_cudagraph_key( @@ -94,26 +119,23 @@ class CudagraphDispatcher: # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode - # LoRA activation cases to specialize the cuda graphs on - if self.vllm_config.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] + # Track whether we have LoRA config (always specialize on count) + self.has_lora_config = self.vllm_config.lora_config is not None + + # Get LoRA cases to capture + lora_cases = self._get_lora_cases() # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs, has_lora in product( + for bs, (has_lora, num_active_loras) in product( self.compilation_config.cudagraph_capture_sizes, lora_cases ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), self._create_padded_batch_descriptor( - bs, False, has_lora + bs, False, has_lora, num_active_loras ).relax_for_mixed_batch_cudagraphs(), ) @@ -132,10 +154,14 @@ class CudagraphDispatcher: for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] - for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): + for bs, (has_lora, num_active_loras) in product( + cudagraph_capture_sizes_for_decode, lora_cases + ): self.add_cudagraph_key( CUDAGraphMode.FULL, - self._create_padded_batch_descriptor(bs, True, has_lora), + self._create_padded_batch_descriptor( + bs, True, has_lora, num_active_loras + ), ) self.keys_initialized = True @@ -146,12 +172,20 @@ class CudagraphDispatcher: uniform_decode: bool, has_lora: bool, disable_full: bool = False, + num_active_loras: int = 0, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). + + Args: + num_tokens: Number of tokens in the batch. + uniform_decode: Whether this is a uniform decode batch. + has_lora: Whether this batch has active LoRA adapters. + disable_full: Whether to disable full cudagraph mode. + num_active_loras: Number of distinct active LoRA adapters. """ if ( not self.keys_initialized @@ -160,8 +194,10 @@ class CudagraphDispatcher: ): return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) + effective_num_active_loras = num_active_loras + batch_desc = self._create_padded_batch_descriptor( - num_tokens, uniform_decode, has_lora + num_tokens, uniform_decode, has_lora, effective_num_active_loras ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 414ae33c6251f..f309103322fd4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2814,6 +2814,7 @@ class GPUModelRunner( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + force_num_active_loras: int | None = None, num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, @@ -2835,11 +2836,13 @@ class GPUModelRunner( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - has_lora = ( - len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None - else force_has_lora + # Compute LoRA state for cudagraph dispatch + num_active_loras = ( + force_num_active_loras + if force_num_active_loras is not None + else len(self.input_batch.lora_id_to_lora_request) ) + has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) dispatch_cudagraph = ( @@ -2848,6 +2851,7 @@ class GPUModelRunner( has_lora=has_lora, uniform_decode=uniform_decode, disable_full=disable_full, + num_active_loras=num_active_loras, ) if not force_eager else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) @@ -4052,6 +4056,7 @@ class GPUModelRunner( remove_lora: bool = True, activate_lora: bool = False, is_graph_capturing: bool = False, + num_active_loras: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -4075,6 +4080,8 @@ class GPUModelRunner( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. + num_active_loras: Number of distinct active LoRAs to capture for. + Only used when cudagraph_specialize_lora_count is True. """ if supports_mm_encoder_only(self.model): # The current dummy run only covers LM execution, so we can skip it. @@ -4156,6 +4163,9 @@ class GPUModelRunner( # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture force_has_lora=activate_lora, + # `force_num_active_loras` is used for cudagraph capture; because we + # need to capture graphs for specific num_active_loras counts + force_num_active_loras=num_active_loras if activate_lora else 0, ) ) @@ -4219,6 +4229,7 @@ class GPUModelRunner( num_sampled_tokens, activate_lora, remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens @@ -4618,13 +4629,8 @@ class GPUModelRunner( cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] + # Build LoRA cases: list of (has_lora, num_active_loras) tuples + lora_cases = self.cudagraph_dispatcher._get_lora_cases() if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() @@ -4689,10 +4695,23 @@ class GPUModelRunner( def _capture_cudagraphs( self, - compilation_cases: list[tuple[int, bool]], + compilation_cases: list[tuple[int, tuple[bool, int]]], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, ): + """ + Capture CUDA graphs for the given compilation cases. + + Args: + compilation_cases: List of (num_tokens, (has_lora, num_active_loras)) + tuples. + - num_tokens: batch size to capture + - has_lora: whether LoRA is active for this capture + - num_active_loras: number of distinct active LoRAs + (0 if not specializing) + cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode. + uniform_decode: Whether this is a uniform decode batch. + """ assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode.valid_runtime_modes() @@ -4710,7 +4729,7 @@ class GPUModelRunner( ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for num_tokens, (activate_lora, num_active_loras) in compilation_cases: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -4742,6 +4761,7 @@ class GPUModelRunner( skip_eplb=True, remove_lora=False, activate_lora=activate_lora, + num_active_loras=num_active_loras, ) self._dummy_run( num_tokens, @@ -4751,6 +4771,7 @@ class GPUModelRunner( skip_eplb=True, remove_lora=False, activate_lora=activate_lora, + num_active_loras=num_active_loras, is_graph_capturing=True, ) self.maybe_remove_all_loras(self.lora_config) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a67246146005c..01d73ccecf84a 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -4,7 +4,10 @@ Define LoRA functionality mixin for model runners. """ +from __future__ import annotations + from contextlib import contextmanager +from typing import TYPE_CHECKING, TypeAlias import numpy as np import torch @@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch -InputBatch = TPUInputBatch | GPUInputBatch +if TYPE_CHECKING: + InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch logger = init_logger(__name__) @@ -129,7 +133,20 @@ class LoRAModelRunnerMixin: num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray | None = None, activate_lora: bool = True, + num_active_loras: int = 0, ): + """ + Context manager to select dummy LoRAs for capture/warmup. + + Args: + lora_config: LoRA configuration, or None if LoRA is disabled. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. + activate_lora: Whether to activate LoRAs (False means no LoRA). + num_active_loras: Number of distinct active LoRAs to use. + - 0: Use all max_loras (default behavior for has_lora=True). + - >0: Use exactly this many distinct LoRAs. + """ if num_sampled_tokens is None: num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) @@ -140,13 +157,24 @@ class LoRAModelRunnerMixin: assert self.lora_manager is not None, "LoRA is not enabled" num_reqs = len(num_scheduled_tokens) - num_loras = lora_config.max_loras + max_loras = lora_config.max_loras + + # Determine how many distinct LoRAs to use + if not activate_lora: + # No LoRA active + effective_num_loras = 0 + elif num_active_loras > 0: + # Specific number of active LoRAs requested + effective_num_loras = min(num_active_loras, max_loras) + else: + # Default: use all max_loras + effective_num_loras = max_loras # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. - if activate_lora: + if effective_num_loras > 0: prompt_lora_mapping = ( - np.arange(num_reqs, dtype=np.int32) % num_loras + np.arange(num_reqs, dtype=np.int32) % effective_num_loras ) + 1 else: prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) @@ -157,16 +185,19 @@ class LoRAModelRunnerMixin: # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) - # Make dummy lora requests + # Make dummy lora requests (only for the active LoRAs) lora_requests: set[LoRARequest] = { LoRARequest( lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_path="/not/a/real/path", ) - for lora_id in range(1, num_loras + 1) + for lora_id in range(1, effective_num_loras + 1) } + # Always call _set_active_loras to ensure the mapping is updated. + # This is important when capturing no-LoRA graphs (effective_num_loras=0) + # after capturing LoRA graphs, as we need to clear the previous mapping. self._set_active_loras( tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests ) @@ -181,11 +212,27 @@ class LoRAModelRunnerMixin: num_sampled_tokens: np.ndarray, activate_lora: bool = True, remove_lora: bool = True, + num_active_loras: int = 0, ): + """ + Context manager for dummy runs with LoRA. + + Args: + lora_config: LoRA configuration. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. + activate_lora: Whether to activate LoRAs. + remove_lora: Whether to remove LoRAs after the context exits. + num_active_loras: Number of distinct active LoRAs to use. + """ with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora + lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + num_active_loras, ), ): yield