diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 209dc85b5bbc9..7fe11abee30d8 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -58,6 +58,7 @@ class CudagraphDispatcher: self.keys_initialized = False self.specialize_lora_count = False + self.specialize_lora_count = False def _create_padded_batch_descriptor( self, @@ -65,6 +66,11 @@ class CudagraphDispatcher: uniform_decode: bool, has_lora: bool, num_active_loras: int = 0, + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -83,6 +89,7 @@ class CudagraphDispatcher: uniform=uniform_decode, has_lora=has_lora, num_active_loras=num_active_loras, + num_active_loras=num_active_loras, ) def add_cudagraph_key( @@ -93,29 +100,6 @@ class CudagraphDispatcher: ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) - def _get_lora_cases(self) -> list[tuple[bool, int]]: - """ - Returns list of (has_lora, num_active_loras) tuples for graph capture. - - Returns cases for each num_active_loras from 1 to max_loras. - If cudagraph_specialize_lora is True, also includes the no-lora case. - - Note: When speculative decoding is enabled, we fall back to capturing - only with max_loras to avoid conflicts with torch.compile during - CUDA graph capture. - """ - if not self.vllm_config.lora_config: - return [(False, 0)] - - max_loras = self.vllm_config.lora_config.max_loras - - # Capture for each num_active_loras from 1 to max_loras - lora_cases = [(True, n) for n in range(1, max_loras + 1)] - # Also capture the no-lora case - if self.compilation_config.cudagraph_specialize_lora: - lora_cases.append((False, 0)) - return lora_cases - def initialize_cudagraph_keys( self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int ): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fabb5f07ba683..d8794e8e4fcaa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4598,6 +4598,24 @@ class GPUModelRunner( self.encoder_cache.clear() gc.collect() + def _get_lora_capture_cases(self) -> list[tuple[bool, int]]: + """ + Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture. + + Returns cases for each num_active_loras from 1 to max_loras. + If cudagraph_specialize_lora is True, also includes the no-lora case. + """ + if not self.lora_config: + return [(False, 0)] + + max_loras = self.lora_config.max_loras + # Capture for each num_active_loras from 1 to max_loras + lora_cases = [(True, n) for n in range(1, max_loras + 1)] + # Also capture the no-lora case if cudagraph_specialize_lora is True + if self.compilation_config.cudagraph_specialize_lora: + lora_cases.append((False, 0)) + return lora_cases + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning(