From 1b2c440cd646dac290b535b86be89d22fbdbeab9 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 9 Oct 2025 14:47:14 +0800
Subject: [PATCH] [Core] Relax the LoRA max rank (#26461)

Signed-off-by: Jee Jee Li
---
 vllm/config/lora.py                       | 2 +-
 vllm/v1/worker/lora_model_runner_mixin.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index f97f2a111d417..60fb3605fda9a 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -103,7 +103,7 @@ class LoRAConfig:
     def __post_init__(self):
         # Setting the maximum rank to 512 should be able to satisfy the vast
         # majority of applications.
-        possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
+        possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
         possible_lora_extra_vocab_size = (256, 512)
         if self.max_lora_rank not in possible_max_ranks:
             raise ValueError(
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index e7358c4271cea..45b7a548d1843 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -28,8 +28,6 @@ logger = init_logger(__name__)
 
 # Defined as a mixin for GPUModelRunner
 class LoRAModelRunnerMixin:
-    LORA_WARMUP_RANK = 8
-
     def load_lora_model(
         self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
     ) -> nn.Module:
@@ -96,7 +94,9 @@ class LoRAModelRunnerMixin:
         assert self.lora_manager is not None, "LoRA is not enabled"
 
         num_loras = lora_config.max_loras
-
+        lora_warmup_rank = (
+            lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
+        )
         # Make dummy lora requests
         lora_requests: set[LoRARequest] = {
             LoRARequest(
@@ -111,7 +111,7 @@ class LoRAModelRunnerMixin:
             # Add the dummy LoRAs here so _set_active_loras doesn't try to
             # load from disk.
             for lr in lora_requests:
-                self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK)
+                self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)
 
             yield
 
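
Note (illustration, not part of the patch): the conditional added in lora_model_runner_mixin.py caps the dummy-LoRA warmup rank at 8 while still honoring smaller configured ranks, such as the newly allowed rank 1. A minimal sketch of that selection logic in isolation; the helper name pick_warmup_rank is chosen here for illustration only and does not exist in vLLM:

    def pick_warmup_rank(max_lora_rank: int) -> int:
        # Use the configured rank when it is below 8 (e.g. rank 1),
        # otherwise warm up with rank-8 dummy LoRAs as before.
        return max_lora_rank if max_lora_rank < 8 else 8

    assert pick_warmup_rank(1) == 1
    assert pick_warmup_rank(8) == 8
    assert pick_warmup_rank(320) == 8

The expression is equivalent to min(max_lora_rank, 8).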