diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index c9297d0071f13..3805e3c72f9f7 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -3,6 +3,7 @@
 import math
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import TypeVar
 
 import regex as re
 
@@ -42,7 +43,13 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)
 
 T = TypeVar("T")
-DEFAULT_WRAPPER_KEY = "language_model"
+DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
+
+
+@dataclass(frozen=True)
+class LoRATarget:
+    wrapper: PunicaWrapperBase
+    prefix: str
 
 
 class AdapterLRUCache(LRUCache[int, T]):
@@ -112,17 +119,16 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
+        self._lora_targets: list[tuple[str, PunicaWrapperBase]] = []
         llm_punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
 
-        # NOTE This assumes the existence of a language model LoRA
-        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
-            DEFAULT_WRAPPER_KEY: llm_punica_wrapper
-        }
+        # NOTE This assumes the existence of a language model LoRA
+        self._lora_targets.append((DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper))
         self._maybe_init_mm(vllm_config)
 
     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -164,15 +170,27 @@
             )
 
         # Only one language model can be included in the model.
-        assert len(self.mm_mapping.language_model == 1)
+        assert len(self.mm_mapping.language_model) == 1
+
+        # Update prefix of language model
+        lm_prefix = (
+            self.mm_mapping.language_model[0]
+            if self.supports_mm
+            else DEFAULT_LANGUAGE_WRAPPER_KEY
+        )
+        _, llm_punica_wrapper = self._lora_targets.pop()
+        self._lora_targets.append((lm_prefix, llm_punica_wrapper))
+
         # Tower wrappers
-        for name in self.mm_mapping.tower_model:
-            self.punica_wrapper_mapping[name] = get_punica_wrapper(
-                num_encoder_tokens,
-                max_batches=self.max_num_seqs * limit_per_prompt,
-                device=self.device,
-                max_loras=self.lora_config.max_loras,
-            )
+        tower_punica_wrapper = get_punica_wrapper(
+            num_encoder_tokens,
+            max_batches=self.max_num_seqs * limit_per_prompt,
+            device=self.device,
+            max_loras=self.lora_config.max_loras,
+        )
+        for prefix in self.mm_mapping.tower_model:
+            self._lora_targets.append((prefix, tower_punica_wrapper))
+
         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):
@@ -185,12 +203,8 @@
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
-                self.punica_wrapper_mapping.update(
-                    {
-                        name: connector_punica_wrapper
-                        for name in self.mm_mapping.connector
-                    }
-                )
+                for prefix in self.mm_mapping.connector:
+                    self._lora_targets.append((prefix, connector_punica_wrapper))
             else:
                 logger.warning_once(
                     "Connector LoRA support disabled: model does not implement "
@@ -198,6 +212,11 @@
                     "determine the connector's token budget for LoRA operations."
                 )
 
+        # Longest-prefix-first
+        self._lora_targets = sorted(
+            self._lora_targets, key=lambda x: len(x[0]), reverse=True
+        )
+
     def __len__(self) -> int:
         return len(self._registered_adapters)
 
@@ -326,20 +345,22 @@
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
         # Default to the main language model wrapper
        if not (self.supports_mm and self.supports_tower_connector_lora):
-            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+            target_prefix = (
+                self.mm_mapping.language_model[0]
+                if self.supports_mm
+                else DEFAULT_LANGUAGE_WRAPPER_KEY
+            )
+        elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
+            target_prefix = self.mm_mapping.tower_model[0]
+        elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
+            target_prefix = self.mm_mapping.connector[0]
         else:
-            if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_prefix = self.mm_mapping.tower_model[0]
-            elif (
-                mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
-            ):
-                target_prefix = self.mm_mapping.connector[0]
-            else:
-                target_prefix = self.mm_mapping.language_model[0]
+            target_prefix = self.mm_mapping.language_model[0]
 
-        target_wrapper = self.punica_wrapper_mapping[target_prefix]
+        target = self._get_lora_target(target_prefix)
+        assert target is not None
 
-        target_wrapper.update_metadata(
+        target.wrapper.update_metadata(
             mapping,
             self.lora_index_to_id,
             self.lora_slots + 1,
@@ -367,7 +388,8 @@
             if not self._match_target_modules(module_name):
                 continue
 
-            if self._filter_unsupported_mm_module(module_name):
+            target = self._get_lora_target(module_name)
+            if target is None:
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
                     " language model, %s will be ignored.",
@@ -433,10 +455,7 @@
                 self._register_packed_modules(module_name)
 
             # All lora layers share the same punica_wrapper based on reference.
-            wrapper = self._get_punica_wrapper_for_module(module_name)
-            if wrapper is None:
-                continue
-            new_module.set_mapping(wrapper)
+            new_module.set_mapping(target.wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -457,7 +476,7 @@
             if (
                 not self._match_target_modules(module_name)
                 or not isinstance(module, BaseLayerWithLoRA)
-                or self._filter_unsupported_mm_module(module_name)
+                or self._get_lora_target(module_name) is None
             ):
                 continue
             parts = module_name.split(".")
@@ -546,42 +565,14 @@
             for target_module in self.supported_lora_modules
         )
 
-    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
+    def _get_lora_target(self, module_name: str) -> LoRATarget | None:
         """
-        Regarding multimodal models, vLLM currently only supports adding LoRA to
-        language model. LoRA for other modules, such as the vision tower, will
-        be filtered out.
+        Determine whether this module supports LoRA and which wrapper to use.
         """
-        if not self.supports_mm:
-            return False
-
-        if self.supports_tower_connector_lora:
-            return self._get_punica_wrapper_for_module(module_name) is None
-
-        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-        return any(module_name.startswith(prefix) for prefix in prefix_lst)
-
-    def _get_punica_wrapper_for_module(
-        self, module_name: str
-    ) -> PunicaWrapperBase | None:
-        """
-        Match the corresponding punica_wrapper based on module_name,
-        and return None if lora is not supported for this module.
-        """
-        best_prefix = None
-        for prefix in self.punica_wrapper_mapping:
-            if prefix == DEFAULT_WRAPPER_KEY:
-                continue
-            # Ensure matching by the longest prefix.
-            if module_name.startswith(prefix) and (
-                best_prefix is None or len(prefix) > len(best_prefix)
-            ):
-                best_prefix = prefix
-
-        if best_prefix is not None:
-            return self.punica_wrapper_mapping[best_prefix]
-
-        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
+        for prefix, wrapper in self._lora_targets:
+            if module_name.startswith(prefix):
+                return LoRATarget(wrapper=wrapper, prefix=prefix)
+        return None
 
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")