diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index a2a498c4067c3..de2d70a7fbbdc 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -240,7 +240,7 @@ def _lora_expand(
         # Each LoRA receives its own set of thread blocks for output
         # computation. If some LoRA doesn't have any tokens to process, its
         # thread blocks simply exit.
-        MAX_LORAS,
+        num_active_loras,
     )
     use_gdc = supports_pdl(inputs.device)
     _lora_expand_kernel[grid](
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 240117b193506..912c546ace8b0 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -220,7 +220,7 @@ def _lora_shrink(
         # Each LoRA receives its own set of thread blocks for output
         # computation. If some LoRA doesn't have any tokens to process, its
         # thread blocks exit early.
-        MAX_LORAS,
+        num_active_loras,
     )
     use_gdc = supports_pdl(inputs.device)
     _lora_shrink_kernel[grid](
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index 7fe11abee30d8..9d731bde644d3 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -58,7 +58,25 @@ class CudagraphDispatcher:
         self.keys_initialized = False
 
         self.specialize_lora_count = False
-        self.specialize_lora_count = False
+
+    def _get_lora_cases(self) -> list[tuple[bool, int]]:
+        """
+        Returns list of (has_lora, num_active_loras) tuples for CUDA graph
+        capture. This is the single source of truth for LoRA capture cases.
+        """
+        lora_config = self.vllm_config.lora_config
+        if lora_config is None:
+            # No LoRA configured - single case with no LoRA
+            return [(False, 0)]
+
+        # LoRA is enabled - capture graphs for different active LoRA counts
+        # Always include the no-LoRA case (for requests without adapters)
+        cases: list[tuple[bool, int]] = [(False, 0)]
+
+        for n in range(1, lora_config.max_loras + 1):
+            cases.append((True, n))
+
+        return cases
 
     def _create_padded_batch_descriptor(
         self,
@@ -66,11 +84,6 @@ class CudagraphDispatcher:
         uniform_decode: bool,
         has_lora: bool,
         num_active_loras: int = 0,
-        self,
-        num_tokens: int,
-        uniform_decode: bool,
-        has_lora: bool,
-        num_active_loras: int = 0,
     ) -> BatchDescriptor:
         max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
         uniform_decode_query_len = self.uniform_decode_query_len
@@ -89,7 +102,6 @@ class CudagraphDispatcher:
             uniform=uniform_decode,
             has_lora=has_lora,
             num_active_loras=num_active_loras,
-            num_active_loras=num_active_loras,
         )
 
     def add_cudagraph_key(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d8794e8e4fcaa..fabb5f07ba683 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4598,24 +4598,6 @@ class GPUModelRunner(
         self.encoder_cache.clear()
         gc.collect()
 
-    def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
-        """
-        Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
-
-        Returns cases for each num_active_loras from 1 to max_loras.
-        If cudagraph_specialize_lora is True, also includes the no-lora case. 
-        """
-        if not self.lora_config:
-            return [(False, 0)]
-
-        max_loras = self.lora_config.max_loras
-        # Capture for each num_active_loras from 1 to max_loras
-        lora_cases = [(True, n) for n in range(1, max_loras + 1)]
-        # Also capture the no-lora case if cudagraph_specialize_lora is True
-        if self.compilation_config.cudagraph_specialize_lora:
-            lora_cases.append((False, 0))
-        return lora_cases
-
     def capture_model(self) -> int:
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             logger.warning(