diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 521bb079da41b..633674d5fb293 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -498,6 +498,14 @@ class LoRAModelManager(AdapterModelManager):
         self._active_adapters.clear()
 
     def _create_lora_modules(self):
+
+        def _parent_module(module_name: str) -> str:
+            # module_name is a dot-separated name.
+            # For example:
+            #   - given an input 'x.y.z', return 'x.y'
+            #   - given an input 'x', return ''
+            return module_name.rpartition('.')[0]
+
         for module_name, module in self.model.named_modules(
                 remove_duplicate=False):
             if isinstance(module, PPMissingLayer):
@@ -529,10 +537,17 @@ class LoRAModelManager(AdapterModelManager):
                     new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
+                logits_processor_module_name = 'logits_processor'
+                parent_module = _parent_module(module_name)
+                if parent_module:
+                    logits_processor_module_name = (
+                        f"{parent_module}.{logits_processor_module_name}")
+
                 logits_processor_module = self.model.get_submodule(
-                    "logits_processor")
+                    logits_processor_module_name)
+
                 new_module = replace_submodule(
-                    self.model, "logits_processor",
+                    self.model, logits_processor_module_name,
                     from_layer_logits_processor(logits_processor_module,
                                                 module, self.lora_slots,
                                                 self.lora_config,
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 6b3291e9c92fa..7148ffe14948e 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -188,16 +188,20 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
     """
     In vLLM, all linear layers support LoRA.
    """
+
     supported_lora_modules: set[str] = set()
-    # step1: traverse the model to get all the linear subfixes.
     for name, module in model.named_modules():
+        # get the embedding modules if the module's embedding_modules
+        # is not empty.
+        embedding_modules = getattr(module, "embedding_modules", None)
+        if embedding_modules is not None:
+            for name in embedding_modules:
+                supported_lora_modules.add(name)
+
+        # get all the linear suffixes.
         if isinstance(module, (LinearBase, )):
             supported_lora_modules.add(name.split(".")[-1])
-    # step 2: get the embedding modules if the model's mbedding_modules
-    # is not empty.
-    if model.embedding_modules:
-        for name in model.embedding_modules:
-            supported_lora_modules.add(name)
+
     return list(supported_lora_modules)
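
Below is a minimal standalone sketch (not part of the patch) of the name resolution that the models.py hunk performs. resolve_logits_processor_name is a hypothetical helper introduced here only for illustration; in the patch itself the path is built inline inside _create_lora_modules:

    def _parent_module(module_name: str) -> str:
        # Same helper as in the patch: 'x.y.z' -> 'x.y', 'x' -> ''.
        return module_name.rpartition('.')[0]

    def resolve_logits_processor_name(lm_head_module_name: str) -> str:
        # Hypothetical helper: builds the submodule path that the patch
        # passes to get_submodule() and replace_submodule().
        parent = _parent_module(lm_head_module_name)
        return f'{parent}.logits_processor' if parent else 'logits_processor'

    # A top-level lm_head keeps the old behaviour.
    assert resolve_logits_processor_name('lm_head') == 'logits_processor'
    # A nested lm_head (e.g. inside a multimodal wrapper) now resolves to
    # the logits processor that lives next to it, instead of always
    # looking up a top-level 'logits_processor'.
    assert (resolve_logits_processor_name('language_model.lm_head')
            == 'language_model.logits_processor')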
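
The utils.py hunk can be exercised in isolation as well. The sketch below uses made-up ToyInner/ToyOuter modules to show why moving the embedding_modules lookup into the named_modules() loop matters: a wrapper model whose inner module carries embedding_modules previously contributed nothing, because only the top-level model.embedding_modules attribute was consulted:

    import torch.nn as nn

    class ToyInner(nn.Module):
        # Hypothetical nested language model that declares its
        # embedding modules the way vLLM models do.
        embedding_modules = {'embed_tokens': 'input_embeddings'}

        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(8, 4)

    class ToyOuter(nn.Module):
        # Hypothetical multimodal wrapper with no embedding_modules
        # attribute of its own.
        def __init__(self):
            super().__init__()
            self.language_model = ToyInner()

    def supported_embedding_modules(model: nn.Module) -> set[str]:
        supported: set[str] = set()
        for _, module in model.named_modules():
            # Per-module getattr, as in the patched loop.
            embedding_modules = getattr(module, 'embedding_modules', None)
            if embedding_modules is not None:
                supported.update(embedding_modules)
        return supported

    # The nested module's embedding targets are now discovered.
    assert supported_embedding_modules(ToyOuter()) == {'embed_tokens'}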