From 764aa451403a691197a7f4c8a3ccc85e493de304 Mon Sep 17 00:00:00 2001
From: bk-201
Date: Fri, 19 Dec 2025 16:57:25 +0000
Subject: [PATCH] [Bugfix][LoRA] Replace _lora_targets list with ordered punica_wrapper_mapping

Replace the internal `_lora_targets` list of (prefix, wrapper) tuples in
LoRAModelManager with a public `punica_wrapper_mapping` OrderedDict that is
kept sorted longest-prefix-first, and update the LoRA manager tests to look
up the language-model wrapper via DEFAULT_LANGUAGE_WRAPPER_KEY.

Signed-off-by: bk-201
---
 tests/lora/test_lora_manager.py | 21 ++++++++++++-----
 vllm/lora/model_manager.py      | 40 ++++++++++++++++++++++++---------
 2 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index b4bb23a9e6b4f..d401db6fdde2a 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -18,7 +18,7 @@ from vllm.lora.layers import (
 from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.model_manager import (
-    DEFAULT_WRAPPER_KEY,
+    DEFAULT_LANGUAGE_WRAPPER_KEY,
     LoRAMapping,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
@@ -185,7 +185,10 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
     assert manager.device == device
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert hasattr(manager, "supported_lora_modules")
     assert sorted(manager.supported_lora_modules) == [
         "dense1",
@@ -278,7 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.remove_adapter(3)
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device
 
 
@@ -401,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 
     assert manager.remove_oldest_adapter()
     assert set(manager.list_adapters()) == {1}
-    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
+    assert (
+        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
+        == device
+    )
     assert manager.device == device
 
 
@@ -514,7 +523,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
     assert worker_adapter_manager.device == device
 
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
     )
     assert punica_wrapper.device == device
 
@@ -621,7 +630,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
     assert worker_adapter_manager.device == device
 
     punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
-        DEFAULT_WRAPPER_KEY
+        DEFAULT_LANGUAGE_WRAPPER_KEY
     )
     assert punica_wrapper.device == device
 
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 3805e3c72f9f7..f685db14af17c 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+from collections import OrderedDict
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import TypeVar
@@ -119,7 +120,7 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
+        self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
         llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
@@ -128,7 +129,9 @@ class LoRAModelManager:
         )
 
         # NOTE This assumes the existence of a language model LoRA
-        self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
+        self.punica_wrapper_mapping.setdefault(
+            DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
+        )
         self._maybe_init_mm(vllm_config)
 
     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -178,8 +181,10 @@ class LoRAModelManager:
             if self.supports_mm
             else DEFAULT_LANGUAGE_WRAPPER_KEY
         )
-        _, llm_punica_wrapper = self._lora_targets.pop()
-        self._lora_targets.append((lm_prefix, llm_punica_wrapper))
+        llm_punica_wrapper = self.punica_wrapper_mapping.pop(
+            DEFAULT_LANGUAGE_WRAPPER_KEY
+        )
+        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
 
         # Tower wrappers
         tower_punica_wrapper = get_punica_wrapper(
@@ -189,7 +194,7 @@ class LoRAModelManager:
             max_loras=self.lora_config.max_loras,
         )
         for prefix in self.mm_mapping.tower_model:
-            self._lora_targets.append((prefix, tower_punica_wrapper))
+            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper
 
         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
@@ -204,7 +209,7 @@ class LoRAModelManager:
                 max_loras=self.lora_config.max_loras,
             )
             for prefix in self.mm_mapping.connector:
-                self._lora_targets.append((prefix, connector_punica_wrapper))
+                self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
         else:
             logger.warning_once(
                 "Connector LoRA support disabled: model does not implement "
...
             )
 
         # Longest-prefix-first
-        self._lora_targets = sorted(
-            self._lora_targets, key=lambda x: len(x[0]), reverse=True
+        self.punica_wrapper_mapping = OrderedDict(
+            sorted(
+                self.punica_wrapper_mapping.items(),
+                key=lambda x: len(x[0]),
+                reverse=True,
+            )
         )
 
     def __len__(self) -> int:
@@ -569,9 +578,20 @@ class LoRAModelManager:
         """
        Determine whether this module supports LoRA and which wrapper to use.
         """
-        for prefix, wrapper in self._lora_targets:
-            if module_name.startswith(prefix):
+        # For language-only models (early return)
+        if not self.supports_mm:
+            wrapper = list(self.punica_wrapper_mapping.values())[0]
+            return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY)
+
+        # For multimodal models
+        for prefix, wrapper in self.punica_wrapper_mapping.items():
+            is_language_model = (
+                prefix == DEFAULT_LANGUAGE_WRAPPER_KEY
+                and module_name.startswith(self.mm_mapping.language_model[0])
+            )
+            if is_language_model or module_name.startswith(prefix):
                 return LoRATarget(wrapper=wrapper, prefix=prefix)
+
         return None
 
     def _register_packed_modules(self, module_full_name: str) -> None:
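
Reviewer note, not part of the patch: a minimal standalone sketch of the
longest-prefix-first lookup that punica_wrapper_mapping implements above.
The Wrapper class and the module-prefix strings are hypothetical stand-ins
for PunicaWrapperBase and real module names, not vLLM code.

from collections import OrderedDict


class Wrapper:
    """Hypothetical stand-in for PunicaWrapperBase; carries only a name."""

    def __init__(self, name: str) -> None:
        self.name = name


# Register one wrapper per module prefix, then reorder longest-prefix-first,
# mirroring the sorted(..., key=lambda x: len(x[0]), reverse=True) in the patch.
mapping = OrderedDict(
    {
        "language_model": Wrapper("llm"),
        "vision_tower": Wrapper("tower"),
        "multi_modal_projector": Wrapper("connector"),
    }
)
mapping = OrderedDict(
    sorted(mapping.items(), key=lambda x: len(x[0]), reverse=True)
)


def lookup(module_name):
    # First matching prefix wins; the longest-prefix-first ordering guarantees
    # the most specific registered prefix is tried before shorter ones.
    for prefix, wrapper in mapping.items():
        if module_name.startswith(prefix):
            return wrapper
    return None


assert lookup("vision_tower.layers.0.attn.qkv").name == "tower"
assert lookup("language_model.layers.0.mlp.gate_proj").name == "llm"
assert lookup("unrelated.module") is None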