diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 50f17ced5dd74..b4bb23a9e6b4f 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -18,6 +18,7 @@ from vllm.lora.layers import (
 from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.model_manager import (
+    DEFAULT_WRAPPER_KEY,
     LoRAMapping,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
@@ -183,9 +184,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.activate_adapter(2)
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
-    assert manager.device == device
-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert hasattr(manager, "supported_lora_modules")
     assert sorted(manager.supported_lora_modules) == [
         "dense1",
@@ -278,8 +278,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.remove_adapter(3)
     with pytest.raises(ValueError):
         assert manager.pin_adapter(3)
-
-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert manager.device == device
@@ -402,7 +401,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.remove_oldest_adapter()
     assert set(manager.list_adapters()) == {1}
-    assert manager.punica_wrapper.device == device
+    assert manager.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY).device == device
     assert manager.device == device
@@ -514,7 +513,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
     )
     assert worker_adapter_manager.device == device
-    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
+    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
+        DEFAULT_WRAPPER_KEY
+    )
+    assert punica_wrapper.device == device


 @pytest.mark.parametrize("device", DEVICES)
@@ -618,7 +620,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
     )
     assert worker_adapter_manager.device == device
-    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
+    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
+        DEFAULT_WRAPPER_KEY
+    )
+    assert punica_wrapper.device == device


 @pytest.mark.parametrize("device", DEVICES)
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 33e147195a6f7..c9297d0071f13 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -42,7 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)
 T = TypeVar("T")
-DEFAULT_WRAPPER_KEY = "__default__"
+DEFAULT_WRAPPER_KEY = "language_model"


 class AdapterLRUCache(LRUCache[int, T]):
@@ -112,15 +112,15 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self.punica_wrapper = get_punica_wrapper(
+        llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
-
+        # NOTE This assumes the existence of a language model LoRA
         self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
-            DEFAULT_WRAPPER_KEY: self.punica_wrapper
+            DEFAULT_WRAPPER_KEY: llm_punica_wrapper
         }
         self._maybe_init_mm(vllm_config)
@@ -163,8 +163,8 @@ class LoRAModelManager:
             mm_budget.get_encoder_budget()
         )
-        self.punica_wrapper_mapping = {}
-
+        # Only one language model can be included in the model.
+        assert len(self.mm_mapping.language_model) == 1
         # Tower wrappers
         for name in self.mm_mapping.tower_model:
             self.punica_wrapper_mapping[name] = get_punica_wrapper(
@@ -173,12 +173,6 @@ class LoRAModelManager:
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-
-        # Language wrapper
-        self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
-            self.punica_wrapper
-        )
-
         # Use wrapper for connector if present.
        if self.mm_mapping.connector:
            if hasattr(self.info, "get_num_mm_connector_tokens"):
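
For context, a minimal, self-contained sketch of the keyed-wrapper pattern this diff moves the manager and tests to. `_Wrapper` and `make_mapping` are hypothetical stand-ins for illustration only; the sole name taken from the diff is `DEFAULT_WRAPPER_KEY`.

# Minimal sketch (hypothetical stand-ins, not vLLM code): the diff replaces a
# single manager-wide `punica_wrapper` attribute with a per-submodule mapping
# keyed by module name, with the language model under DEFAULT_WRAPPER_KEY.
from dataclasses import dataclass

DEFAULT_WRAPPER_KEY = "language_model"  # mirrors the new constant in the diff


@dataclass
class _Wrapper:  # hypothetical stand-in for PunicaWrapperBase
    device: str


def make_mapping(device: str, towers: list[str]) -> dict[str, _Wrapper]:
    # One wrapper for the language model plus one per tower module,
    # mirroring _init_punica_wrapper() and _maybe_init_mm() above.
    mapping = {DEFAULT_WRAPPER_KEY: _Wrapper(device)}
    for name in towers:
        mapping[name] = _Wrapper(device)
    return mapping


mapping = make_mapping("cuda:0", ["vision_tower"])
# The tests now assert against the mapped wrapper instead of the removed
# manager-wide attribute:
assert mapping.get(DEFAULT_WRAPPER_KEY).device == "cuda:0"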