From 038e9be4eb7a63189c8980845d80cb96957b9919 Mon Sep 17 00:00:00 2001
From: Andy Lo
Date: Sat, 30 Aug 2025 16:37:39 +0100
Subject: [PATCH] [LoRA] Much faster startup when LoRA is enabled (#23777)

Signed-off-by: Andy Lo
Co-authored-by: Jee Jee Li
---
 vllm/v1/worker/gpu_model_runner.py        | 11 ++++++---
 vllm/v1/worker/gpu_worker.py              |  5 +++-
 vllm/v1/worker/lora_model_runner_mixin.py | 30 ++++++++++++++++-------
 3 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c6d50c17f2b4d..d6717892d4aec 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2213,6 +2213,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         uniform_decode: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
+        remove_lora: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -2230,6 +2231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
             is_profile: If True, this is a profile run.
+            remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
@@ -2317,7 +2319,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens):
+                                            num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2708,11 +2710,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                 force_attention=force_attention,
                                 uniform_decode=uniform_decode,
-                                skip_eplb=True)
+                                skip_eplb=True,
+                                remove_lora=False)
             self._dummy_run(num_tokens,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             uniform_decode=uniform_decode,
-                            skip_eplb=True)
+                            skip_eplb=True,
+                            remove_lora=False)
+        self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 2088bfff5bb39..2e7d6685377f2 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -308,7 +308,10 @@ class Worker(WorkerBase):
         # We skip EPLB here since we don't want to record dummy metrics
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size, skip_eplb=True)
+            self.model_runner._dummy_run(size,
+                                         skip_eplb=True,
+                                         remove_lora=False)
+        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
 
         # Warmup and tune the kernels used during model execution before
         # cuda graph capture.
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index 84ed46989ea97..4b5f27d27541b 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -5,7 +5,7 @@ Define LoRA functionality mixin for model runners.
 """
 
 from contextlib import contextmanager
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -87,7 +87,9 @@ class LoRAModelRunnerMixin:
                                       lora_requests)
 
     @contextmanager
-    def maybe_setup_dummy_loras(self, lora_config):
+    def maybe_setup_dummy_loras(self,
+                                lora_config: Optional[LoRAConfig],
+                                remove_lora: bool = True):
         if lora_config is None:
             yield
         else:
@@ -114,10 +116,11 @@ class LoRAModelRunnerMixin:
                 yield
 
             # __exit__ code
-            self.lora_manager.remove_all_adapters()
+            if remove_lora:
+                self.lora_manager.remove_all_adapters()
 
     @contextmanager
-    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
+    def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
@@ -151,13 +154,22 @@ class LoRAModelRunnerMixin:
             yield
 
     @contextmanager
-    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
-        with self.maybe_setup_dummy_loras(
-                lora_config), self.maybe_select_dummy_loras(
-                    lora_config, num_scheduled_tokens):
+    def maybe_dummy_run_with_lora(self,
+                                  lora_config: Optional[LoRAConfig],
+                                  num_scheduled_tokens: np.ndarray,
+                                  remove_lora: bool = True):
+        with (
+                self.maybe_setup_dummy_loras(lora_config, remove_lora),
+                self.maybe_select_dummy_loras(lora_config,
+                                              num_scheduled_tokens),
+        ):
             yield
 
+    def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
+        if lora_config is None:
+            return
+        self.lora_manager.remove_all_adapters()
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")