[LoRA] Much faster startup when LoRA is enabled (#23777)

Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Andy Lo 2025-08-30 16:37:39 +01:00 committed by GitHub
parent 68a349114f
commit 038e9be4eb
3 changed files with 33 additions and 13 deletions
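
Before this change, the dummy LoRA adapters set up for a warm-up _dummy_run were unconditionally removed again when the run finished, so every compile/warm-up pass rebuilt them from scratch. The new remove_lora flag lets the warm-up paths keep the dummy adapters alive across consecutive dummy runs and remove them exactly once afterwards via the new maybe_remove_all_loras helper. A minimal caller-side sketch of the pattern (runner and warmup_sizes are illustrative stand-ins, not the actual vLLM call sites):

def warm_up(runner, warmup_sizes):
    # Reuse the same dummy LoRA adapters across all warm-up passes ...
    for size in sorted(warmup_sizes, reverse=True):
        runner._dummy_run(size, skip_eplb=True, remove_lora=False)
    # ... and tear them down exactly once at the end.
    runner.maybe_remove_all_loras(runner.lora_config)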


@@ -2213,6 +2213,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         uniform_decode: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
+        remove_lora: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -2230,6 +2231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
             is_profile: If True, this is a profile run.
+            remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
@@ -2317,7 +2319,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens):
+                                            num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2708,11 +2710,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                             cudagraph_runtime_mode=CUDAGraphMode.NONE,
                             force_attention=force_attention,
                             uniform_decode=uniform_decode,
-                            skip_eplb=True)
+                            skip_eplb=True,
+                            remove_lora=False)
             self._dummy_run(num_tokens,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             uniform_decode=uniform_decode,
-                            skip_eplb=True)
+                            skip_eplb=True,
+                            remove_lora=False)
+        self.maybe_remove_all_loras(self.lora_config)

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """


@@ -308,7 +308,10 @@ class Worker(WorkerBase):
         # We skip EPLB here since we don't want to record dummy metrics
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size, skip_eplb=True)
+            self.model_runner._dummy_run(size,
+                                         skip_eplb=True,
+                                         remove_lora=False)
+        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

         # Warmup and tune the kernels used during model execution before
         # cuda graph capture.


@@ -5,7 +5,7 @@ Define LoRA functionality mixin for model runners.
 """

 from contextlib import contextmanager
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -87,7 +87,9 @@ class LoRAModelRunnerMixin:
                                    lora_requests)

     @contextmanager
-    def maybe_setup_dummy_loras(self, lora_config):
+    def maybe_setup_dummy_loras(self,
+                                lora_config: Optional[LoRAConfig],
+                                remove_lora: bool = True):
         if lora_config is None:
             yield
         else:
@@ -114,10 +116,11 @@ class LoRAModelRunnerMixin:
             yield
             # __exit__ code
-            self.lora_manager.remove_all_adapters()
+            if remove_lora:
+                self.lora_manager.remove_all_adapters()

     @contextmanager
-    def maybe_select_dummy_loras(self, lora_config: LoRAConfig,
+    def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
@@ -151,13 +154,22 @@ class LoRAModelRunnerMixin:
             yield

     @contextmanager
-    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
-        with self.maybe_setup_dummy_loras(
-                lora_config), self.maybe_select_dummy_loras(
-                    lora_config, num_scheduled_tokens):
+    def maybe_dummy_run_with_lora(self,
+                                  lora_config: Optional[LoRAConfig],
+                                  num_scheduled_tokens: np.ndarray,
+                                  remove_lora: bool = True):
+        with (
+                self.maybe_setup_dummy_loras(lora_config, remove_lora),
+                self.maybe_select_dummy_loras(lora_config,
+                                              num_scheduled_tokens),
+        ):
             yield

+    def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
+        if lora_config is None:
+            return
+        self.lora_manager.remove_all_adapters()
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
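
For reference, the conditional-teardown shape that maybe_setup_dummy_loras now follows can be sketched as below; AdapterManager-style calls such as add_dummy_adapters are hypothetical stand-ins for the LoRA manager internals, which this diff does not show:

from contextlib import contextmanager

@contextmanager
def maybe_setup_dummy_adapters(manager, lora_config, remove_lora: bool = True):
    if lora_config is None:
        # LoRA disabled: nothing to set up or tear down.
        yield
        return
    manager.add_dummy_adapters(lora_config)  # hypothetical setup step
    try:
        yield
    finally:
        # Warm-up callers pass remove_lora=False so the adapters survive
        # until maybe_remove_all_loras() is called once at the end.
        if remove_lora:
            manager.remove_all_adapters()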