From cb3e73e4c8142b5ce8ac34efc2fa04d90f142dc5 Mon Sep 17 00:00:00 2001
From: fade_away <1028552010@qq.com>
Date: Sat, 1 Feb 2025 12:52:07 +0800
Subject: [PATCH] [BugFix] fix wrong output when using lora and
 num_scheduler_steps=8 (#11161)

FIX issue https://github.com/vllm-project/vllm/issues/9688
https://github.com/vllm-project/vllm/issues/11086
#12487

---------

Signed-off-by: Jee Jee Li
Co-authored-by: weilong.yu
Co-authored-by: Jee Jee Li
---
 vllm/worker/model_runner.py | 4 ++++
 vllm/worker/worker.py       | 3 ---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 160c0662ce976..322d91d62ce46 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1346,6 +1346,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
+        if self.lora_config:
+            # Remove dummy loras.
+            assert self.lora_manager is not None
+            self.remove_all_loras()
         return
 
     def remove_all_loras(self):
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 24bba79fedd75..1d2884d3ddf51 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -264,10 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
         logger.info(msg)
 
-        # Final cleanup
-        if self.model_runner.lora_manager:
-            self.model_runner.remove_all_loras()
 
         gc.collect()
 
         return num_gpu_blocks, num_cpu_blocks
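
For context, below is a minimal reproduction sketch of the scenario this patch addresses: LoRA combined with multi-step scheduling (num_scheduler_steps=8), where the dummy LoRAs added for memory profiling are now removed at the end of the profile run itself rather than later in the Worker. The base model name and adapter path are placeholders, not taken from this patch, and the snippet assumes a vLLM build from around the time of this change.

# Hypothetical reproduction sketch for the bug fixed above; the model name
# and LoRA adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    num_scheduler_steps=8,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    sampling_params,
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
)
for out in outputs:
    # Before this patch, the completion here could be wrong compared to the
    # single-step (num_scheduler_steps=1) result; with the fix, they agree.
    print(out.outputs[0].text)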