From f98e3075880bde73f5bdc20ab688e224352f6880 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 8 May 2025 00:17:12 +0800 Subject: [PATCH] [Bugfix] Fix missing lora name mapping for lora without prefix (#17793) Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/lora/test_utils.py | 69 +++++++++++++++++++++++++++++++++------- vllm/lora/utils.py | 6 ++-- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 67f3866beff5..0d4e0bf681f2 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict +from typing import NamedTuple, Optional from unittest.mock import patch import pytest @@ -9,52 +10,96 @@ from torch import nn from vllm.lora.utils import (get_adapter_absolute_path, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models.utils import WeightsMapper + + +class LoRANameParserTestConfig(NamedTuple): + name: str + module_name: str + is_lora_a: bool + is_bias: bool + weights_mapper: Optional[WeightsMapper] = None def test_parse_fine_tuned_lora_name_valid(): - fixture = { - ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False), - ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False), - ( + fixture = [ + LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", + "lm_head", True, False), + LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", + "lm_head", False, False), + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, False, ), - } - for name, module_name, is_lora_a, is_bias in fixture: + # Test with WeightsMapper + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + ] + for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name) + is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) def test_parse_fine_tuned_lora_name_invalid(): diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 883ca938ea1a..01064e5d007e 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -117,16 +117,18 @@ def parse_fine_tuned_lora_name( # LoRA weight qualified name usually starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. - if "base_model.model." in name: + if name.startswith("base_model.model."): name = name.replace("base_model.model.", "") name = weights_mapper._map_name(name) if weights_mapper else name # recover the prefix `base_model.model.` name = "base_model.model." + name + else: + name = weights_mapper._map_name(name) if weights_mapper else name # In some situations, we may not start with `base_model.model.`. # If we don't (e.g., ibm-granite/granite-speech-3.3-8b), # we should keep the prefix intact. - start_index = 2 if "base_model.model." in name else 0 + start_index = 2 if name.startswith("base_model.model.") else 0 parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A"