diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index fc59f433f3632..209dc85b5bbc9 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -109,14 +109,6 @@ class CudagraphDispatcher:
 
         max_loras = self.vllm_config.lora_config.max_loras
 
-        # When speculative decoding is enabled, only capture with max_loras
-        # to avoid torch.compile conflicts during CUDA graph capture
-        if self.vllm_config.speculative_config is not None:
-            lora_cases = [(True, max_loras)]
-            if self.compilation_config.cudagraph_specialize_lora:
-                lora_cases.append((False, 0))
-            return lora_cases
-
         # Capture for each num_active_loras from 1 to max_loras
         lora_cases = [(True, n) for n in range(1, max_loras + 1)]
         # Also capture the no-lora case
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 673eb5e6df2d2..fabb5f07ba683 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4598,24 +4598,6 @@ class GPUModelRunner(
         self.encoder_cache.clear()
         gc.collect()
 
-    def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
-        """
-        Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
-
-        Returns cases for each num_active_loras from 1 to max_loras.
-        If cudagraph_specialize_lora is True, also includes the no-lora case.
-        """
-        if not self.lora_config:
-            return [(False, 0)]
-
-        max_loras = self.lora_config.max_loras
-        # Capture for each num_active_loras from 1 to max_loras
-        lora_cases = [(True, n) for n in range(1, max_loras + 1)]
-        # Also capture the no-lora case if cudagraph_specialize_lora is True
-        if self.compilation_config.cudagraph_specialize_lora:
-            lora_cases.append((False, 0))
-        return lora_cases
-
     def capture_model(self) -> int:
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             logger.warning(
@@ -4654,7 +4636,7 @@ class GPUModelRunner(
         assert cudagraph_mode is not None
 
         # Build LoRA cases: list of (has_lora, num_active_loras) tuples
-        lora_cases = self._get_lora_capture_cases()
+        lora_cases = self.cudagraph_dispatcher._get_lora_cases()
 
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()