diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 174b9f0b97760..22a45b60ca399 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
+from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
         """Get LoRA for a given module by name"""
         return self.loras.get(module_name, None)
 
+    def check_lora_name(self, lora_name: str) -> bool:
+        return lora_name in self.loras
+
     # (yard1): TODO see if we can derive target_embedding_padding automatically
     @classmethod
     def from_lora_tensors(
@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
+
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f"{self.model.__class__.__name__}."
@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
             # In case the model only supports LoRA for
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping"))
+        self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, BaseLayerWithLoRA] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
                      lora_model.id, index)
         self.lora_index_to_id[index] = lora_model.id
         for module_name, module in self.modules.items():
-            module_lora = lora_model.get_lora(module_name)
+            module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
                 module_lora.optimize()
                 # Bias is not explicitly enabled with the flag enable_lora_bias.
@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
         replaced_module: Set[str] = set()
         has_replacement = False
         for r in new_module_names:
-            lora = lora_model.get_lora(r)
+            lora = self._get_lora_layer_weights(lora_model, r)
             replacement_loras.append(lora)
             if lora:
                 has_replacement = True
@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
             if replacement_loras[i]:
                 continue
             replacement_loras[i] = None
+        # HACK Temporary solution for the pooling model.
+        if self.is_pooling_model and not lora_model.check_lora_name(
+                module_name):
+            replaced_module_name = module_name.replace("model.", "")
+            if lora_model.check_lora_name(replaced_module_name):
+                module_name = replaced_module_name
         lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
             replacement_loras)
         # Remove the modules that have been replaced.
         for module in replaced_module:
             lora_model.loras.pop(module, None)
 
+    def _get_lora_layer_weights(
+            self, lora_model: LoRAModel,
+            module_name: str) -> Optional[LoRALayerWeights]:
+        org_module_name = module_name
+        if self.is_pooling_model and not lora_model.check_lora_name(
+                module_name):
+            # If it's a pooling model and the layer name is not found,
+            # remove the prefix 'model.' and search again.
+            module_name = module_name.replace("model.", "")
+            if lora_model.check_lora_name(module_name):
+                org_module_name = module_name
+                logger.info_once(
+                    "For the pooling model, successfully loaded the LoRA "
+                    "weights after removing the prefix 'model.'.")
+        return lora_model.get_lora(org_module_name)
+
     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
                                   self._deactivate_adapter)
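A minimal standalone sketch of the fallback this patch adds, not vLLM code: the function and dict names below (lookup_lora_weights, a plain dict standing in for LoRAModel.loras) are hypothetical, and only illustrate the "strip the 'model.' prefix and retry" lookup performed by _get_lora_layer_weights for pooling models.

# Pooling models wrap the transformer as `self.model`, so runtime module
# names carry a "model." prefix that the LoRA checkpoint keys may lack.
from typing import Dict, Optional


def lookup_lora_weights(loras: Dict[str, object], module_name: str,
                        is_pooling_model: bool) -> Optional[object]:
    # Direct hit: the checkpoint key matches the runtime module name.
    if module_name in loras:
        return loras[module_name]
    # Fallback for pooling models: retry without the "model." prefix.
    if is_pooling_model:
        stripped = module_name.replace("model.", "")
        if stripped in loras:
            return loras[stripped]
    return None


loras = {"layers.0.self_attn.qkv_proj": "LoRALayerWeights(...)"}
print(lookup_lora_weights(loras, "model.layers.0.self_attn.qkv_proj", True))
# -> LoRALayerWeights(...)   (found after stripping the prefix)
print(lookup_lora_weights(loras, "model.layers.0.self_attn.qkv_proj", False))
# -> None                    (non-pooling models get no fallback)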